Skip to content

Commit

Permalink
fix parameter validation
Browse files Browse the repository at this point in the history
nnothing1 committed Dec 15, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 714c14e commit 1a864f4
Showing 2 changed files with 39 additions and 27 deletions.
35 changes: 22 additions & 13 deletions consul_registry.go
Original file line number Diff line number Diff line change
@@ -144,17 +144,25 @@ func (c *consulRegistry) Register(info *registry.Info) error {
svcInfo.Check = c.opts.check
}

var ttl time.Duration
if c.opts.check.TTL != "" {
ttl, err = time.ParseDuration(c.opts.check.TTL)
if err != nil {
return err
}
if ttl <= time.Second {
return errors.New("consul check ttl must be greater than one second")
}
}

if err := c.consulClient.Agent().ServiceRegister(svcInfo); err != nil {
return err
}

if c.opts.check.TTL != "" {
if ttl, err := time.ParseDuration(c.opts.check.TTL); err != nil {
return err
} else {
return c.startTTLHeartbeat(ttl)
}
c.startTTLHeartbeat(ttl)
}

return nil
}

@@ -165,18 +173,20 @@ func (c *consulRegistry) Deregister(info *registry.Info) error {
return err
}

err = c.consulClient.Agent().ServiceDeregister(svcID)
if err != nil {
return err
}

if c.cancelUpdateTTL != nil {
defer c.cancelUpdateTTL()
c.cancelUpdateTTL()
}
return c.consulClient.Agent().ServiceDeregister(svcID)

return nil
}

// startTTLHeartbeat start a goroutine to periodically update TTL.
func (c *consulRegistry) startTTLHeartbeat(ttl time.Duration) error {
if ttl <= 1*time.Second {
return errors.New("consul check ttl must be greater than one second")
}

func (c *consulRegistry) startTTLHeartbeat(ttl time.Duration) {
ctx, cancel := context.WithCancel(context.Background())
c.cancelUpdateTTL = cancel
go func() {
@@ -196,7 +206,6 @@ func (c *consulRegistry) startTTLHeartbeat(ttl time.Duration) error {
}
}
}()
return nil
}

func validateRegistryInfo(info *registry.Info) error {
31 changes: 17 additions & 14 deletions consul_test.go
Original file line number Diff line number Diff line change
@@ -220,21 +220,24 @@ func TestRegisterWithTTLCheck(t *testing.T) {
err = cRegistryWithTTL.Register(info)
assert.Nil(t, err)
// wait for health check passing
time.Sleep(time.Second * 6)

list, _, err := consulClient.Health().Service(testSvcName, "", true, nil)
assert.Nil(t, err)
if assert.Equal(t, 1, len(list)) {
ss := list[0]
gotSvc := ss.Service
assert.Equal(t, testSvcName, gotSvc.Service)
assert.Equal(t, testSvcAddr.String(), fmt.Sprintf("%s:%d", gotSvc.Address, gotSvc.Port))
assert.Equal(t, testSvcWeight, gotSvc.Weights.Passing)
assert.Equal(t, len(tagMap), len(gotSvc.Tags))
for _, tag := range gotSvc.Tags {
kv := strings.Split(tag, ":")
k, v := kv[0], kv[1]
assert.Equal(t, tagMap[k], v)
var list []*consulapi.ServiceEntry
for i := 0; i < 3; i++ {
time.Sleep(time.Second * 5)
list, _, err = consulClient.Health().Service(testSvcName, "", true, nil)
assert.Nil(t, err)
if assert.Equal(t, 1, len(list)) {
ss := list[0]
gotSvc := ss.Service
assert.Equal(t, testSvcName, gotSvc.Service)
assert.Equal(t, testSvcAddr.String(), fmt.Sprintf("%s:%d", gotSvc.Address, gotSvc.Port))
assert.Equal(t, testSvcWeight, gotSvc.Weights.Passing)
assert.Equal(t, len(tagMap), len(gotSvc.Tags))
for _, tag := range gotSvc.Tags {
kv := strings.Split(tag, ":")
k, v := kv[0], kv[1]
assert.Equal(t, tagMap[k], v)
}
}
}
}

0 comments on commit 1a864f4

Please sign in to comment.