diff --git a/cmd/client.go b/cmd/client.go index d949050..545bde0 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -7,7 +7,13 @@ import ( "context" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "crypto/rand" + "crypto/rsa" + "crypto/ecdsa" "encoding/json" + "encoding/pem" + "encoding/base64" "flag" "fmt" "io/ioutil" @@ -22,10 +28,12 @@ import ( "syscall" "time" "os/exec" + "bytes" "github.com/go-co-op/gocron" "github.com/gorilla/websocket" "golang.org/x/sys/unix" + "github.com/fullsailor/pkcs7" ) type WsConn struct { @@ -55,6 +63,7 @@ const ( var ( cloudDiscoveryHost = "https://discovery.open-lan.org/v1/devices/" + estHost = "est.certificates.open-lan.org" ) var ( @@ -112,6 +121,7 @@ var ( PublicIpLookup = "ifconfig.me" VlanStatsLast = map[string]InterfaceCounter{} PortStatsLast = map[string]OLSInterfaceCounter{} + estServerList = []string{} ) func sendMessageToController() error { @@ -567,13 +577,378 @@ func getControllerUrl() string { return ControllerUrl } +func getEstServer(domain string) []string { + CAA := []string{} + + // Execute the dig command to retrieve CAA records + cmd := exec.Command("dig", "+short", "caa", domain) + var out bytes.Buffer + cmd.Stdout = &out + err := cmd.Run() + if err != nil { + logger.Error("command dig failed, err is %s", err.Error()) + return CAA + } + + output := strings.TrimSpace(out.String()) + if output == "" { + logger.Error("%s no CAA record", domain) + return CAA + } + + // Split each row and extract the third field + lines := strings.Split(output, "\n") + var thirdFields []string + + for _, line := range lines { + fields := strings.Fields(line) + if len(fields) >= 3 { + thirdField := fields[2] + thirdField = strings.Trim(thirdField, `"`) + thirdFields = append(thirdFields, thirdField) + } + } + + for _, field := range thirdFields { + CAA = append(CAA, field) + } + + logger.Info("CAA: %v", CAA) + return CAA +} + +// Obtain signature algorithm based on private key type +func getSignatureAlgorithm(privateKey interface{}) x509.SignatureAlgorithm { + switch privateKey.(type) { + case *rsa.PrivateKey: + return x509.SHA256WithRSA + case *ecdsa.PrivateKey: + return x509.ECDSAWithSHA256 + default: + return x509.UnknownSignatureAlgorithm + } +} + +func getOperationalCA(estServer string) bool { + client, err := tlsclient(true) + if err != nil { + logger.Error("tls client created failed, err is %s", err.Error()) + return false + } + + caUrl := "https://" + estServer + "/cacerts" + resp, err := client.Get(caUrl) + if err != nil { + logger.Error("request failed, err is %s", err.Error()) + return false + } + defer resp.Body.Close() + + logger.Info("resp.StatusCode %d", resp.StatusCode) + + calist, err := ioutil.ReadAll(resp.Body) + if err != nil { + logger.Error("read response failed, err is %s", err.Error()) + return false + } + + decoded, err := base64.StdEncoding.DecodeString(string(calist)) + if err != nil { + logger.Info("Decode String failed, err is %s", err.Error()) + return false + } + + certPEM := decoded + + // parse PKCS#7 data + p7, err := pkcs7.Parse(certPEM) + if err != nil { + logger.Info("parse PKCS7 failed, err is %s", err.Error()) + return false + } + + var certs []*x509.Certificate + if p7.Certificates != nil { + certs = p7.Certificates + } + + logger.Info("Converted P7 to PEM") + + if len(certs) == 0 { + logger.Info("cannot find operational.ca from response") + return false + } + + // save to operational.ca + file, err := os.Create(operationalCAPath) + if err != nil { + logger.Error("cannot create file operational.ca, err is %s", err.Error()) + return false + } + defer file.Close() + + for _, cert := range certs { + err = pem.Encode(file, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }) + if err != nil { + logger.Error("cannot write CA to file operational.ca, err is %s", err.Error()) + return false + } + } + + logger.Info("Persistently stored operational.ca") + return true +} + +func getOperationalCert(estServer string, reenroll bool) bool { + result := false + // Read existing private key file (PEM format) + logger.Info("start to get operational.pem") + privateKeyPEM, err := os.ReadFile(keyPath) + if err != nil { + logger.Error("read %s failed, err is %s", keyPath, err.Error()) + return result + } + + block, _ := pem.Decode(privateKeyPEM) + if block == nil { + logger.Error("Invalid PEM format private key") + return result + } + + // parse private key + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + logger.Error("parse private key failed, err is %s", err.Error()) + return result + } + + pemData := []byte{} + if !reenroll { + pemData, err = ioutil.ReadFile(certPath) + if err != nil { + logger.Error("read %s failed, err is %s", certPath, err.Error()) + return result + } + } else { + pemData, err = ioutil.ReadFile(operationalPath) + if err != nil { + logger.Error("read %s failed, err is %s", operationalPath, err.Error()) + return result + } + } + + blockCert, _ := pem.Decode(pemData) + if blockCert == nil || blockCert.Type != "CERTIFICATE" { + logger.Error("Invalid PEM format cert key") + return result + } + + cert, err := x509.ParseCertificate(blockCert.Bytes) + if err != nil { + logger.Error("parse cert key failed, err is %s", err.Error()) + return result + } + + subjectCert := cert.Subject + + subject := pkix.Name{ + CommonName: subjectCert.CommonName, + Organization: subjectCert.Organization, + } + + // create CSR template + csrTemplate := x509.CertificateRequest{ + Subject: subject, + SignatureAlgorithm: getSignatureAlgorithm(privateKey), + } + + // create CSR + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, privateKey) + if err != nil { + logger.Error("create CSR failed, err is %s", err.Error()) + return result + } + + logger.Info("Generated CSR") + + encoded := base64.StdEncoding.EncodeToString(csrBytes) + reader := strings.NewReader(encoded) + + caURL := "https://" + estServer + "/simpleenroll" + if reenroll { + caURL = "https://" + estServer + "/simplereenroll" + } + + req, _ := http.NewRequest("POST", caURL, reader) + req.Header.Set("Content-Type", "application/pkcs10-base64") + req.Header.Set("Accept", "application/pkcs7") + + client, err := tlsclient(true) + if err != nil { + logger.Error("tls client created failed, err is %s", err.Error()) + return result + } + + resp, err := client.Do(req) + if err != nil { + logger.Error("request failed, err is %s", err.Error()) + return result + } + + logger.Info("EST succeeded") + + defer resp.Body.Close() + + + certPEM, err := ioutil.ReadAll(resp.Body) + if err != nil { + logger.Error("read response failed, err is %s", err.Error()) + return result + } + + decoded, err := base64.StdEncoding.DecodeString(string(certPEM)) + if err != nil { + logger.Error("Decode String failed, err is %s", err.Error()) + return result + } + + certPEM = decoded + + // parse PKCS#7 data + p7, err := pkcs7.Parse(certPEM) + if err != nil { + logger.Error("parse PKCS7 failed, err is %s", err.Error()) + return result + } + + var certs []*x509.Certificate + if p7.Certificates != nil { + certs = p7.Certificates + } + + logger.Info("Converted P7 to PEM") + + if len(certs) == 0 { + logger.Error("cannot find operational certificate from response") + return result + } + + // save to operational.pem + file, err := os.Create(operationalPath) + if err != nil { + logger.Error("cannot create opreational.pem, err is %s", err.Error()) + return result + } + defer file.Close() + + for _, cert := range certs { + err = pem.Encode(file, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }) + if err != nil { + logger.Error("write opreational.pem failed, err is %s", err.Error()) + return result + } + } + + logger.Info("Persistently stored operational.pem") + result = true + return result +} + +func operationalExpireCheck() bool { + certPEM, err := ioutil.ReadFile(operationalPath) + if err != nil { + logger.Error("read opreational.pem failed, err is %s", err.Error()) + return true + } + + block, _ := pem.Decode(certPEM) + if block == nil { + logger.Error("decode opreational.pem failed, err is %s", err.Error()) + return true + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + logger.Error("decode opreational.pem failed, err is %s", err.Error()) + return true + } + + now := time.Now() + notBefore := cert.NotBefore + notAfter := cert.NotAfter + + totalDuration := notAfter.Sub(notBefore) + twoThirdsTime := notBefore.Add(totalDuration * 2 / 3) + + logger.Info("Certificate issuance time: %v", notBefore) + logger.Info("Certificate expiration date: %v", notAfter) + + if now.After(twoThirdsTime) || now.Equal(twoThirdsTime) { + logger.Info("The certificate has reached or exceeded 2/3 of its usage period") + return true + } else { + logger.Info("Certificate status is normal") + } + + return false +} + +func setQAorProduct() bool { + result := false + pemData, err := ioutil.ReadFile(certPath) + if err != nil { + logger.Error("read %s failed, err is %s", certPath, err.Error()) + return result + } + + blockCert, _ := pem.Decode(pemData) + if blockCert == nil || blockCert.Type != "CERTIFICATE" { + logger.Error("Invalid PEM format cert key") + return result + } + + cert, err := x509.ParseCertificate(blockCert.Bytes) + if err != nil { + logger.Error("parse cert key failed, err is %s", err.Error()) + return result + } + + issue := cert.Issuer.String() + logger.Info("(Issuer): %s\n", issue) + if strings.Contains(issue, "OpenLAN Demo") { + estHost = "qaest.certificates.open-lan.org:8001" + cloudDiscoveryHost = "https://discovery-qa.open-lan.org/v1/devices/" + } + + return true +} + func firstContact() bool { + resEnv := setQAorProduct() + if !resEnv { + return true + } ControllerAddr = getControllerUrl() if ControllerAddr == "" { logger.Error("Could not get ControllerAddr") + estServerList = []string{estHost} + updateOperationalPem() return true } + estServerList = getEstServer(ControllerAddr) + estServerList = append(estServerList, estHost) + + res := updateOperationalPem() + if !res { + return true + } // set Cloud controller FQDN to connection instance, set it in redis serverInfo := map[string]interface{}{ @@ -589,6 +964,53 @@ func firstContact() bool { return false } +func updateOperationalPem() bool { + for _, oneServer := range estServerList { + res := getOperationalCA(oneServer) + if res { + break + } + } + hasOperationalPem := false + _, err := ioutil.ReadFile(operationalPath) + if err != nil { + logger.Warn("Reading %s failed, err is %s", operationalPath, err.Error()) + for _, oneServer := range estServerList { + res := getOperationalCert(oneServer, false) + if res{ + hasOperationalPem = true + break + } + } + } else { + isExpire := operationalExpireCheck() + if isExpire { + for _, oneServer := range estServerList { + res := getOperationalCert(oneServer, true) + if res{ + hasOperationalPem = true + break + } + } + } + hasOperationalPem = true + } + + return hasOperationalPem +} + +func periodUpdateOpertaional() { + ticker := time.NewTicker(60 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + updateOperationalPem() + } + } +} + func main() { flag.Parse() log.SetFlags(0) @@ -648,6 +1070,8 @@ func main() { break } + go periodUpdateOpertaional() + // Start the main event loop in a goroutine go startEventLoop()