From 6acdcb742932ed2e98152eb2bdc52003505fe275 Mon Sep 17 00:00:00 2001 From: Mohammad Mahdi Date: Thu, 2 Oct 2025 11:14:59 +0330 Subject: [PATCH] Improve tls certificate verification --- internal/mqtt/tls.go | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/internal/mqtt/tls.go b/internal/mqtt/tls.go index 6db5d99..9031527 100644 --- a/internal/mqtt/tls.go +++ b/internal/mqtt/tls.go @@ -4,6 +4,8 @@ import ( "MQTTLogger/config" "crypto/tls" "crypto/x509" + "fmt" + "net/url" "os" "go.uber.org/zap" @@ -24,11 +26,25 @@ func NewTLSConfig(logger *zap.Logger, config *config.Config) *tls.Config { } return &tls.Config{ - RootCAs: certpool, - // We use the provided cert not the one server sends. - ClientAuth: tls.NoClientCert, - ClientCAs: nil, - InsecureSkipVerify: true, // I know - Certificates: nil, + RootCAs: certpool, + InsecureSkipVerify: true, + Certificates: nil, + VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + + opts := x509.VerifyOptions{Roots: certpool} + if _, err := cert.Verify(opts); err != nil { + return fmt.Errorf("failed to verify chain: %w", err) + } + + expectedCN, _ := url.Parse(config.URI) + if cert.Subject.CommonName != expectedCN.Hostname() { + return fmt.Errorf("unexpected CN, expected %s but got %s", expectedCN.Host, cert.Subject.CommonName) + } + return nil + }, } }