package threat

import (
	"crypto/rand"
	"sync"
	"time"

	"dns-server/config"
	"dns-server/logger"
)

// ThreatEngine is the DNS threat-detection engine. It matches queried
// domains against a threat database and keeps per-client query statistics.
type ThreatEngine struct {
	config    *config.ThreatConfig
	alertMgr  *AlertManager
	dbManager *ThreatDatabaseManager
	// clientStats is keyed by the query's source IP. It is guarded by mutex.
	clientStats map[string]*ClientQueryStats
	mutex       sync.RWMutex
}

// NewThreatEngine creates a threat-detection engine with an empty
// per-client statistics table.
//
// dbManager may be nil, in which case no database lookups are performed.
func NewThreatEngine(cfg *config.ThreatConfig, alertMgr *AlertManager, dbManager *ThreatDatabaseManager) *ThreatEngine {
	return &ThreatEngine{
		config:      cfg,
		alertMgr:    alertMgr,
		dbManager:   dbManager,
		clientStats: make(map[string]*ClientQueryStats),
	}
}

// CheckQuery checks a single DNS query for threats and records it in the
// per-client statistics.
//
// It returns the alerts raised for this query. The result is nil when
// nothing matched.
func (e *ThreatEngine) CheckQuery(sourceIP, domain, queryType string) []*ThreatAlert {
	var alerts []*ThreatAlert

	logger.Debug("威胁检测引擎检查域名", "domain", domain, "sourceIP", sourceIP)

	// Look the domain up in the threat database.
	if alert := e.checkThreatDatabase(sourceIP, domain); alert != nil {
		// checkThreatDatabase does not receive the query type, so fill it in
		// here; otherwise every alert would carry an empty QueryType.
		alert.QueryType = queryType
		alerts = append(alerts, alert)
	}

	// Update the per-client query statistics.
	e.updateClientStats(sourceIP, queryType)

	return alerts
}

// checkThreatDatabase looks domain up in the threat database and builds an
// alert when it matches.
//
// It returns nil when no database manager is configured or the domain is
// unknown. The alert level is derived from the entry's RiskLevel:
// 3 → high, 2 → medium, anything else → low. The returned alert's QueryType
// is left empty; CheckQuery fills it in.
func (e *ThreatEngine) checkThreatDatabase(sourceIP, domain string) *ThreatAlert {
	if e.dbManager == nil {
		return nil
	}

	// Fetch the threat information for the domain.
	threatInfo := e.dbManager.GetThreatInfo(domain)
	if threatInfo == nil {
		return nil
	}

	logger.Info("检测到威胁域名",
		"domain", domain,
		"type", threatInfo.Type,
		"name", threatInfo.Name,
		"riskLevel", threatInfo.RiskLevel)

	// Map the database risk level onto an alert level.
	var alertLevel string
	switch threatInfo.RiskLevel {
	case 3:
		alertLevel = AlertLevelHigh
	case 2:
		alertLevel = AlertLevelMedium
	default:
		alertLevel = AlertLevelLow
	}

	return &ThreatAlert{
		ID:          generateAlertID(),
		Timestamp:   time.Now(),
		Level:       alertLevel,
		Type:        AlertTypeSuspiciousDomain,
		Description: "威胁域名数据库匹配",
		Details:     "威胁类型: " + threatInfo.Type + ", 威胁名称: " + threatInfo.Name,
		SourceIP:    sourceIP,
		Domain:      domain,
		QueryType:   "",
		Resolved:    false,
	}
}

// updateClientStats records one query from sourceIP of the given type.
//
// It increments the client's query count, sets its last-query time, and
// counts the query type.
//
// NOTE(review): entries are never removed here, so the map grows with every
// distinct source IP seen. Confirm that another part of the package prunes
// clientStats; otherwise spoofed source addresses can grow it without bound.
func (e *ThreatEngine) updateClientStats(sourceIP, queryType string) {
	now := time.Now()

	e.mutex.Lock()
	defer e.mutex.Unlock()

	stats, exists := e.clientStats[sourceIP]
	if !exists {
		stats = &ClientQueryStats{
			QueryCount:    1,
			NXDomainCount: 0,
			LastQueryTime: now,
			QueryTypes:    make(map[string]int),
		}
		e.clientStats[sourceIP] = stats
	} else {
		stats.QueryCount++
		stats.LastQueryTime = now
	}

	// Entries created elsewhere in the package might lack the map; writing
	// to a nil map would panic.
	if stats.QueryTypes == nil {
		stats.QueryTypes = make(map[string]int)
	}
	stats.QueryTypes[queryType]++
}

// generateAlertID returns an alert ID of the form
// "YYYYMMDDhhmmss-" followed by 8 random alphanumeric characters.
func generateAlertID() string {
	return time.Now().Format("20060102150405") + "-" + randomString(8)
}

// randomString returns length random characters drawn uniformly from
// [a-zA-Z0-9]. A non-positive length yields "".
//
// Randomness comes from crypto/rand. The previous implementation derived
// each character from the wall clock and slept between characters. That
// produced correlated, collision-prone IDs and blocked the caller.
func randomString(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// Largest multiple of len(charset) that fits in a byte. Bytes at or
	// above it are rejected so every character is equally likely.
	const limit = 256 - 256%len(charset)

	if length <= 0 {
		return ""
	}

	result := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(result) < length {
		if _, err := rand.Read(buf); err != nil {
			// crypto/rand failing is practically unheard of. Degrade to a
			// clock-seeded xorshift generator rather than panicking or
			// sleeping in the query path.
			seed := uint64(time.Now().UnixNano()) | 1
			for len(result) < length {
				seed ^= seed << 13
				seed ^= seed >> 7
				seed ^= seed << 17
				result = append(result, charset[seed%uint64(len(charset))])
			}
			return string(result)
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			result = append(result, charset[int(b)%len(charset)])
			if len(result) == length {
				break
			}
		}
	}
	return string(result)
}