Skip to content

Commit

Permalink
feat: edit server config online (#162)
Browse files Browse the repository at this point in the history
* feat: edit server config online

* change back to 30s

* reset geoip state

* use json

* refactor

* fix state
  • Loading branch information
uubulb authored Jan 31, 2025
1 parent 033b4ee commit d88ed03
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 45 deletions.
110 changes: 94 additions & 16 deletions cmd/agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ var (
lastReportHostInfo time.Time
lastReportIPInfo time.Time

hostStatus atomic.Bool
ipStatus atomic.Bool
hostStatus atomic.Bool
ipStatus atomic.Bool
reloadStatus atomic.Bool

dnsResolver = &net.Resolver{PreferGo: true}
httpClient = &http.Client{
Expand All @@ -76,11 +77,13 @@ var (
Timeout: time.Second * 30,
Transport: &http3.RoundTripper{},
}

reloadSigChan = make(chan struct{})
)

var (
println = logger.DefaultLogger.Println
printf = logger.DefaultLogger.Printf
println = logger.Println
printf = logger.Printf
)

const (
Expand Down Expand Up @@ -271,7 +274,6 @@ func run() {

retry := func() {
initialized = false
println("Error to close connection ...")
if conn != nil {
conn.Close()
}
Expand Down Expand Up @@ -313,31 +315,41 @@ func run() {

errCh := make(chan error)

wCtx, wCancel := context.WithCancel(context.Background())
// 执行 Task
tasks, err := client.RequestTask(context.Background())
tasks, err := client.RequestTask(wCtx)
if err != nil {
printf("请求任务失败: %v", err)
retry()
continue
}
go receiveTasksDaemon(tasks, errCh)

reportState, err := client.ReportSystemState(context.Background())
reportState, err := client.ReportSystemState(wCtx)
if err != nil {
printf("上报状态信息失败: %v", err)
retry()
continue
}
go reportStateDaemon(reportState, errCh)

for i := 0; i < 2; i++ {
err = <-errCh
if i == 0 {
tasks.CloseSend()
reportState.CloseSend()
for i := 0; i < 2; {
select {
case <-reloadSigChan:
println("Reloading...")
wCancel()
case err := <-errCh:
if i == 0 {
tasks.CloseSend()
reportState.CloseSend()
println("Error to close connection ...")
}
i++
printf("worker exit to main: %v", err)
}
printf("worker exit to main: %v", err)
}

wCancel()
close(errCh)

retry()
Expand Down Expand Up @@ -403,7 +415,7 @@ func runService(action string, path string) {

err = s.Run()
if err != nil {
logger.DefaultLogger.Error(err)
logger.Error(err)
}
}

Expand Down Expand Up @@ -456,6 +468,10 @@ func doTask(task *pb.Task) *pb.TaskResult {
case model.TaskTypeFM:
handleFMTask(task)
return nil
case model.TaskTypeReportConfig:
handleReportConfigTask(&result)
case model.TaskTypeApplyConfig:
handleApplyConfigTask(task)
case model.TaskTypeKeepalive:
default:
printf("不支持的任务: %v", task)
Expand Down Expand Up @@ -798,6 +814,68 @@ func handleCommandTask(task *pb.Task, result *pb.TaskResult) {
result.Delay = float32(time.Since(startedAt).Seconds())
}

func handleReportConfigTask(result *pb.TaskResult) {
if agentConfig.DisableCommandExecute {
result.Data = "此 Agent 已禁止命令执行"
return
}

if reloadStatus.Load() {
result.Data = "another reload is in process"
return
}

println("Executing Report Config Task")

c, err := util.Json.Marshal(agentConfig)
if err != nil {
result.Data = err.Error()
return
}

result.Data = string(c)
result.Successful = true
}

func handleApplyConfigTask(task *pb.Task) {
if agentConfig.DisableCommandExecute {
return
}

if !reloadStatus.CompareAndSwap(false, true) {
return
}

println("Executing Apply Config Task")

var tmpConfig model.AgentConfig
json := []byte(task.GetData())
if err := util.Json.Unmarshal(json, &tmpConfig); err != nil {
printf("Validate Config failed: %v", err)
reloadStatus.Store(false)
return
}

if err := model.ValidateConfig(&tmpConfig, true); err != nil {
printf("Validate Config failed: %v", err)
reloadStatus.Store(false)
return
}

println("Will reload workers in 30 seconds")
time.AfterFunc(time.Second*30, func() {
println("Applying new configuration...")
agentConfig.Apply(&tmpConfig)
agentConfig.Save()
geoipReported = false
logger.SetEnable(agentConfig.Debug)
monitor.InitConfig(&agentConfig)
monitor.CustomEndpoints = agentConfig.CustomIPApi
reloadStatus.Store(false)
reloadSigChan <- struct{}{}
})
}

type WindowSize struct {
Cols uint32
Rows uint32
Expand Down Expand Up @@ -1022,11 +1100,11 @@ func ioStreamKeepAlive(ctx context.Context, stream pb.NezhaService_IOStreamClien
for {
select {
case <-ctx.Done():
log.Printf("IOStream KeepAlive stopped: %v", ctx.Err())
printf("IOStream KeepAlive stopped: %v", ctx.Err())
return
case <-ticker.C:
if err := stream.Send(&pb.IOStreamData{Data: []byte{}}); err != nil {
log.Printf("IOStream KeepAlive failed: %v", err)
printf("IOStream KeepAlive failed: %v", err)
return
}
}
Expand Down
74 changes: 51 additions & 23 deletions model/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,6 @@ func (c *AgentConfig) Read(path string) error {
return err
}

if c.ReportDelay == 0 {
c.ReportDelay = 3
}

if c.IPReportPeriod == 0 {
c.IPReportPeriod = 1800
} else if c.IPReportPeriod < 30 {
c.IPReportPeriod = 30
}

if c.Server == "" {
return errors.New("server address should not be empty")
}

if c.ClientSecret == "" {
return errors.New("client_secret must be specified")
}

if c.ReportDelay < 1 || c.ReportDelay > 4 {
return errors.New("report-delay ranges from 1-4")
}

if c.UUID == "" {
if uuid, err := uuid.GenerateUUID(); err == nil {
c.UUID = uuid
Expand All @@ -106,7 +84,7 @@ func (c *AgentConfig) Read(path string) error {
}
}

return nil
return ValidateConfig(c, false)
}

func (c *AgentConfig) Save() error {
Expand All @@ -123,6 +101,56 @@ func (c *AgentConfig) Save() error {
return os.WriteFile(c.filePath, data, 0600)
}

func (c *AgentConfig) Apply(new *AgentConfig) {
c.Debug = new.Debug
c.HardDrivePartitionAllowlist = new.HardDrivePartitionAllowlist
c.NICAllowlist = new.NICAllowlist
c.GPU = new.GPU
c.Temperature = new.Temperature
c.SkipConnectionCount = new.SkipConnectionCount
c.SkipProcsCount = new.SkipProcsCount
c.DisableAutoUpdate = new.DisableAutoUpdate
c.DisableForceUpdate = new.DisableForceUpdate
c.DisableCommandExecute = new.DisableCommandExecute
c.ReportDelay = new.ReportDelay
c.DisableNat = new.DisableNat
c.DisableSendQuery = new.DisableSendQuery
c.IPReportPeriod = new.IPReportPeriod
c.SelfUpdatePeriod = new.SelfUpdatePeriod
}

func ValidateConfig(c *AgentConfig, isRemoteEdit bool) error {
if c.ReportDelay == 0 {
c.ReportDelay = 3
}

if c.IPReportPeriod == 0 {
c.IPReportPeriod = 1800
} else if c.IPReportPeriod < 30 {
c.IPReportPeriod = 30
}

if c.ReportDelay < 1 || c.ReportDelay > 4 {
return errors.New("report-delay ranges from 1-4")
}

if !isRemoteEdit {
if c.Server == "" {
return errors.New("server address should not be empty")
}

if c.ClientSecret == "" {
return errors.New("client_secret must be specified")
}

if _, err := uuid.ParseUUID(c.UUID); err != nil {
return err
}
}

return nil
}

type kubeyaml struct{}

// Unmarshal parses the given YAML bytes.
Expand Down
2 changes: 2 additions & 0 deletions model/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ const (
TaskTypeNAT
TaskTypeReportHostInfoDeprecated
TaskTypeFM
TaskTypeReportConfig
TaskTypeApplyConfig
)

type TerminalTask struct {
Expand Down
30 changes: 27 additions & 3 deletions pkg/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

var (
DefaultLogger = &ServiceLogger{enabled: true, logger: service.ConsoleLogger}
defaultLogger = &ServiceLogger{enabled: true, logger: service.ConsoleLogger}

loggerOnce sync.Once
)
Expand All @@ -21,18 +21,42 @@ type ServiceLogger struct {

func InitDefaultLogger(enabled bool, logger service.Logger) {
loggerOnce.Do(func() {
DefaultLogger.enabled = enabled
DefaultLogger.logger = logger
defaultLogger.enabled = enabled
defaultLogger.logger = logger
})
}

func SetEnable(enable bool) {
defaultLogger.SetEnable(enable)
}

func Println(v ...interface{}) {
defaultLogger.Println(v...)
}

func Printf(format string, v ...interface{}) {
defaultLogger.Printf(format, v...)
}

func Error(v ...interface{}) error {
return defaultLogger.Error(v...)
}

func Errorf(format string, v ...interface{}) error {
return defaultLogger.Errorf(format, v...)
}

func NewServiceLogger(enable bool, logger service.Logger) *ServiceLogger {
return &ServiceLogger{
enabled: enable,
logger: logger,
}
}

func (s *ServiceLogger) SetEnable(enable bool) {
s.enabled = enable
}

func (s *ServiceLogger) Println(v ...interface{}) {
if s.enabled {
s.logger.Infof("NEZHA@%s>> %v", time.Now().Format("2006-01-02 15:04:05"), fmt.Sprint(v...))
Expand Down
2 changes: 1 addition & 1 deletion pkg/monitor/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ var (
Version string
agentConfig *model.AgentConfig

printf = logger.DefaultLogger.Printf
printf = logger.Printf
)

var (
Expand Down
4 changes: 2 additions & 2 deletions pkg/monitor/myip.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ var (

// UpdateIP 按设置时间间隔更新IP地址的缓存
func FetchIP(useIPv6CountryCode bool) *pb.GeoIP {
logger.DefaultLogger.Println("正在更新本地缓存IP信息")
logger.Println("正在更新本地缓存IP信息")

if retryTimes > 2 && time.Now().Before(latestRetryAt.Add(latestRetryAt.Sub(failedStartedAt)*time.Duration(2))) {
logger.DefaultLogger.Println("IP地址获取失败次数过多,fallback到agent连接IP")
logger.Println("IP地址获取失败次数过多,fallback到agent连接IP")
return &pb.GeoIP{
Use6: false,
Ip: &pb.IP{
Expand Down

0 comments on commit d88ed03

Please sign in to comment.