package threat

import (
	"crypto/rand"
	"strconv"
	"sync"
	"time"

	"dns-server/config"
	"dns-server/logger"
)

// ThreatEngine checks DNS queries against the threat-domain database and
// keeps per-client query statistics.
type ThreatEngine struct {
	config    *config.ThreatConfig
	alertMgr  *AlertManager
	dbManager *ThreatDatabaseManager
	// clientStats is keyed by client source IP; updateClientStats accesses it
	// while holding mutex.
	clientStats map[string]*ClientQueryStats
	mutex       sync.RWMutex
}

// NewThreatEngine returns a ThreatEngine wired to the given configuration,
// alert manager and threat database manager, with an empty statistics table.
func NewThreatEngine(config *config.ThreatConfig, alertMgr *AlertManager, dbManager *ThreatDatabaseManager) *ThreatEngine {
	return &ThreatEngine{
		config:      config,
		alertMgr:    alertMgr,
		dbManager:   dbManager,
		clientStats: make(map[string]*ClientQueryStats),
	}
}

// CheckQuery inspects a single DNS query and returns the alerts it raised.
//
// The domain is looked up in the threat database, and the per-client
// statistics for sourceIP are updated regardless of the lookup outcome.
// The returned slice is nil when no threat was found.
func (e *ThreatEngine) CheckQuery(sourceIP, domain, queryType string) []*ThreatAlert {
	var alerts []*ThreatAlert

	logger.Debug("威胁检测引擎检查域名", "domain", domain, "sourceIP", sourceIP)

	// Look the domain up in the threat database.
	if alert := e.checkThreatDatabase(sourceIP, domain); alert != nil {
		// checkThreatDatabase does not receive the query type, so record it
		// here instead of emitting the alert with an empty QueryType.
		alert.QueryType = queryType
		alerts = append(alerts, alert)
	}

	// Record the query in the per-client statistics.
	e.updateClientStats(sourceIP, queryType)

	return alerts
}

// checkThreatDatabase looks domain up through the database manager and builds
// an alert when threat information is found. It returns nil when no database
// manager is configured or the lookup yields nothing.
func (e *ThreatEngine) checkThreatDatabase(sourceIP, domain string) *ThreatAlert {
	if e.dbManager == nil {
		return nil
	}

	// Fetch the threat information recorded for this domain.
	threatInfo := e.dbManager.GetThreatInfo(domain)
	if threatInfo == nil {
		return nil
	}

	logger.Info("检测到威胁域名", "domain", domain, "type", threatInfo.Type, "name", threatInfo.Name, "riskLevel", threatInfo.RiskLevel)

	// Map the database risk level onto an alert level
	// (database convention: 1 = high, 2 = medium, 3 = low).
	var alertLevel string
	switch threatInfo.RiskLevel {
	case 1:
		alertLevel = AlertLevelHigh
	case 2:
		alertLevel = AlertLevelMedium
	case 3:
		alertLevel = AlertLevelLow
	default:
		// Unrecognised levels are reported as low.
		alertLevel = AlertLevelLow
	}

	// Render the level as a decimal number. Adding it to '0' only produces a
	// digit for 0-9 and yields arbitrary characters for anything else.
	riskLevel := strconv.Itoa(int(threatInfo.RiskLevel))

	return &ThreatAlert{
		ID:          generateAlertID(),
		Timestamp:   time.Now(),
		Level:       alertLevel,
		Type:        threatInfo.Type, // database "type" column
		Description: threatInfo.Name, // database "name" column
		Details:     "威胁类型:" + threatInfo.Type + ", 威胁名称:" + threatInfo.Name + ", 风险等级:" + riskLevel,
		SourceIP:    sourceIP,
		Domain:      domain,
		QueryType:   "",
		Resolved:    false,
	}
}

// updateClientStats records one query of the given type for sourceIP,
// creating the client's statistics entry on first sight. The whole update
// runs under e.mutex.
func (e *ThreatEngine) updateClientStats(sourceIP, queryType string) {
	e.mutex.Lock()
	defer e.mutex.Unlock()

	stats, exists := e.clientStats[sourceIP]
	if !exists {
		// A fresh entry starts from zero counters; the shared update below
		// brings QueryCount to 1 for the first query.
		stats = &ClientQueryStats{QueryTypes: make(map[string]int)}
		e.clientStats[sourceIP] = stats
	}

	stats.QueryCount++
	stats.LastQueryTime = time.Now()

	// Per-query-type counter.
	stats.QueryTypes[queryType]++
}

// generateAlertID returns an alert ID made of the current time formatted as
// yyyyMMddHHmmss, a dash, and an 8-character random suffix.
func generateAlertID() string {
	return time.Now().Format("20060102150405") + "-" + randomString(8)
}

// randomString returns a string of the requested length drawn from
// [a-zA-Z0-9]. It returns "" when length is not positive.
//
// Bytes come from crypto/rand, so no per-character sleep is needed and the
// characters are not derived from the wall clock.
func randomString(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

	if length <= 0 {
		return ""
	}

	result := make([]byte, length)
	if _, err := rand.Read(result); err != nil {
		// Extremely unlikely. Fall back to a clock-seeded linear congruential
		// sequence so the caller still gets a usable, non-constant suffix.
		seed := uint64(time.Now().UnixNano())
		for i := range result {
			seed = seed*6364136223846793005 + 1442695040888963407
			result[i] = byte(seed >> 56)
		}
	}

	// Reduce each byte onto the charset. The modulo introduces a slight bias
	// (256 is not a multiple of 62), which is acceptable for an ID suffix.
	for i, b := range result {
		result[i] = charset[int(b)%len(charset)]
	}
	return string(result)
}