Skip to content

Commit

Permalink
add crash protection for GetAllCertificates
Browse files Browse the repository at this point in the history
  • Loading branch information
harshavardhana committed Jan 9, 2025
1 parent 707f4b6 commit d1b1bf6
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
17 changes: 17 additions & 0 deletions certs/certs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,23 @@ func TestNewManager(t *testing.T) {
if err == nil {
t.Fatal("Expected to fail but got success")
}

allCerts := c.GetAllCertificates()
var found bool
for _, cert := range allCerts {
if cert.Issuer.String() != "CN=minio.io,OU=Engineering,O=Minio,L=Redwood City,ST=CA,C=US" {
t.Error("Unexpected cert issuer found")
}
found = true
}
if !found {
t.Error("atleast one public cert is expected")
}

var nilMgr *certs.Manager
if len(nilMgr.GetAllCertificates()) != 0 {
t.Error("no public cert is expected")
}
}

func TestValidPairAfterWrite(t *testing.T) {
Expand Down
36 changes: 36 additions & 0 deletions certs/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ func NewManager(ctx context.Context, certFile, keyFile string, loadX509KeyPair L

// UpdateReloadDuration set custom symlink reload duration
func (m *Manager) UpdateReloadDuration(t time.Duration) {
if m == nil {
return
}

m.lock.Lock()
m.duration = t
m.lock.Unlock()
Expand All @@ -107,6 +111,10 @@ func (m *Manager) UpdateReloadDuration(t time.Duration) {
// If there is already a certificate with the same base name it will be
// replaced by the newly added one.
func (m *Manager) AddCertificate(certFile, keyFile string) (err error) {
if m == nil {
return nil
}

certFile, err = filepath.Abs(certFile)
if err != nil {
return err
Expand Down Expand Up @@ -194,6 +202,10 @@ func (m *Manager) reloader() <-chan struct{} {
// ReloadOnSignal specifies one or more signals that will trigger certificates reloading.
// If called multiple times with the same signal certificates
func (m *Manager) ReloadOnSignal(sig ...os.Signal) {
if m == nil {
return
}

if len(sig) == 0 {
return
}
Expand All @@ -214,6 +226,10 @@ func (m *Manager) ReloadOnSignal(sig ...os.Signal) {

// ReloadCerts will forcefully reload all certs.
func (m *Manager) ReloadCerts() {
if m == nil {
return
}

m.lock.RLock()
defer m.lock.RUnlock()
for _, ch := range m.reloadCerts {
Expand All @@ -228,6 +244,10 @@ func (m *Manager) ReloadCerts() {
// watchSymlinks starts an endless loop reloading the
// certFile and keyFile periodically.
func (m *Manager) watchSymlinks(watch pair, reload <-chan struct{}) {
if m == nil {
return
}

t := time.NewTimer(m.duration)
defer t.Stop()

Expand Down Expand Up @@ -263,6 +283,10 @@ func (m *Manager) watchSymlinks(watch pair, reload <-chan struct{}) {
// Once an event occurs it reloads the private key and certificate that
// has changed, if any.
func (m *Manager) watchFileEvents(watch pair, events chan notify.EventInfo, reload <-chan struct{}) {
if m == nil {
return
}

for {
select {
case <-m.done:
Expand Down Expand Up @@ -301,6 +325,10 @@ func (m *Manager) watchFileEvents(watch pair, events chan notify.EventInfo, relo
// found GetCertificate returns the certificate loaded from the
// Public file.
func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if m == nil {
return nil, errors.New("certs: no server certificate is supported by peer")
}

m.lock.RLock()
defer m.lock.RUnlock()

Expand Down Expand Up @@ -368,6 +396,10 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
// found GetClientCertificate returns the certificate loaded from the
// Public file.
func (m *Manager) GetClientCertificate(reqInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
if m == nil {
return nil, errors.New("certs: no client certificate is supported by peer")
}

m.lock.RLock()
defer m.lock.RUnlock()

Expand Down Expand Up @@ -406,6 +438,10 @@ func isSymlink(file string) (bool, error) {

// GetAllCertificates returns all the certificates loaded
func (m *Manager) GetAllCertificates() []*x509.Certificate {
if m == nil {
return nil
}

m.lock.RLock()
defer m.lock.RUnlock()

Expand Down

0 comments on commit d1b1bf6

Please sign in to comment.