Files
dns-server/threat/engine.go
T
Alex Yang efebce3c39 whois
2026-04-01 12:22:55 +08:00

124 lines
3.1 KiB
Go

package threat
import (
	"math/rand"
	"sync"
	"time"

	"dns-server/config"
	"dns-server/logger"
)
// ThreatEngine is the threat-detection engine. It matches queried domains
// against a threat-domain database and keeps per-client query statistics.
type ThreatEngine struct {
	config    *config.ThreatConfig   // threat-detection settings
	alertMgr  *AlertManager          // alert manager (presumably receives raised alerts — confirm)
	dbManager *ThreatDatabaseManager // threat-domain lookups; nil skips the database check

	// mutex serializes access to clientStats.
	mutex       sync.RWMutex
	clientStats map[string]*ClientQueryStats // query statistics keyed by source IP
}
// NewThreatEngine creates a ThreatEngine.
//
// cfg holds the threat-detection settings, alertMgr the alert manager and
// dbManager the threat-domain database. dbManager may be nil, in which case
// CheckQuery skips the database lookup. The per-client statistics map starts
// empty.
//
// The first parameter is named cfg (not config) so it does not shadow the
// imported config package inside the function body.
func NewThreatEngine(cfg *config.ThreatConfig, alertMgr *AlertManager, dbManager *ThreatDatabaseManager) *ThreatEngine {
	return &ThreatEngine{
		config:      cfg,
		alertMgr:    alertMgr,
		dbManager:   dbManager,
		clientStats: make(map[string]*ClientQueryStats),
	}
}
// CheckQuery inspects one DNS query for threats and records it in the
// per-client statistics.
//
// sourceIP is the querying client's address, domain the queried name and
// queryType the query's record type (string form assumed, e.g. "A" — confirm
// with caller). It returns the alerts raised for this query, or nil when
// nothing matched. Client statistics are updated whether or not an alert is
// raised.
func (e *ThreatEngine) CheckQuery(sourceIP, domain, queryType string) []*ThreatAlert {
	var alerts []*ThreatAlert

	logger.Debug("威胁检测引擎检查域名", "domain", domain, "sourceIP", sourceIP)

	// Look the domain up in the threat-domain database.
	if alert := e.checkThreatDatabase(sourceIP, domain); alert != nil {
		// checkThreatDatabase is not given the query type, so record it here;
		// otherwise the alert would report an empty QueryType.
		alert.QueryType = queryType
		alerts = append(alerts, alert)
	}

	e.updateClientStats(sourceIP, queryType)
	return alerts
}
// checkThreatDatabase looks domain up in the threat-domain database.
//
// It returns nil when no database manager is configured or the domain has no
// threat entry. Otherwise it logs the match and returns a suspicious-domain
// alert whose level follows the entry's RiskLevel: 3 maps to high, 2 to
// medium, and any other value to low. The alert's QueryType and Resolved
// fields are left at their zero values.
func (e *ThreatEngine) checkThreatDatabase(sourceIP, domain string) *ThreatAlert {
	if e.dbManager == nil {
		return nil
	}
	info := e.dbManager.GetThreatInfo(domain)
	if info == nil {
		return nil
	}

	logger.Info("检测到威胁域名", "domain", domain, "type", info.Type, "name", info.Name, "riskLevel", info.RiskLevel)

	var level string
	switch {
	case info.RiskLevel == 3:
		level = AlertLevelHigh
	case info.RiskLevel == 2:
		level = AlertLevelMedium
	default:
		level = AlertLevelLow
	}

	details := "威胁类型: " + info.Type + ", 威胁名称: " + info.Name
	alert := &ThreatAlert{
		ID:          generateAlertID(),
		Timestamp:   time.Now(),
		Level:       level,
		Type:        AlertTypeSuspiciousDomain,
		Description: "威胁域名数据库匹配",
		Details:     details,
		SourceIP:    sourceIP,
		Domain:      domain,
	}
	return alert
}
// updateClientStats records one query from sourceIP in the per-client
// statistics. It bumps the query count, stamps the last-query time, and
// increments the counter for queryType. The first query from an address
// creates its entry. The whole update runs under e.mutex.
//
// NOTE(review): entries are never removed here, so clientStats grows with
// every distinct source IP; confirm that something else prunes it.
func (e *ThreatEngine) updateClientStats(sourceIP, queryType string) {
	e.mutex.Lock()
	defer e.mutex.Unlock()

	now := time.Now()
	entry, found := e.clientStats[sourceIP]
	if !found {
		entry = &ClientQueryStats{QueryTypes: make(map[string]int)}
		e.clientStats[sourceIP] = entry
	}

	entry.QueryCount++
	entry.LastQueryTime = now
	entry.QueryTypes[queryType]++
}
// generateAlertID builds an alert identifier of the form
// "YYYYMMDDhhmmss-xxxxxxxx". The prefix is the current local time to the
// second and the suffix is eight characters from randomString.
func generateAlertID() string {
	const (
		stampLayout = "20060102150405"
		suffixLen   = 8
	)
	stamp := time.Now().Format(stampLayout)
	return stamp + "-" + randomString(suffixLen)
}
// randomString returns a string of length characters, each drawn uniformly
// from [a-zA-Z0-9]. It returns "" when length is zero or negative.
//
// The characters come from math/rand rather than the wall clock. Indexing the
// charset by time.Now().UnixNano() made the output predictable, let back-to-back
// calls collide on coarse clocks, and required sleeping between characters on
// the query path. The top-level math/rand functions are safe for concurrent
// use and are randomly seeded on Go 1.20+. This is not cryptographic
// randomness, which is fine for non-secret IDs.
func randomString(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if length <= 0 {
		return ""
	}
	result := make([]byte, length)
	for i := range result {
		result[i] = charset[rand.Intn(len(charset))]
	}
	return string(result)
}