126 lines
3.4 KiB
Go
126 lines
3.4 KiB
Go
package threat
|
|
|
|
import (
	"math/rand"
	"strconv"
	"sync"
	"time"

	"dns-server/config"
	"dns-server/logger"
)
|
|
|
|
// ThreatEngine 威胁检测引擎
|
|
type ThreatEngine struct {
|
|
config *config.ThreatConfig
|
|
alertMgr *AlertManager
|
|
dbManager *ThreatDatabaseManager
|
|
clientStats map[string]*ClientQueryStats
|
|
mutex sync.RWMutex
|
|
}
|
|
|
|
// NewThreatEngine 创建威胁检测引擎
|
|
func NewThreatEngine(config *config.ThreatConfig, alertMgr *AlertManager, dbManager *ThreatDatabaseManager) *ThreatEngine {
|
|
return &ThreatEngine{
|
|
config: config,
|
|
alertMgr: alertMgr,
|
|
dbManager: dbManager,
|
|
clientStats: make(map[string]*ClientQueryStats),
|
|
}
|
|
}
|
|
|
|
// CheckQuery 检查DNS查询是否存在威胁
|
|
func (e *ThreatEngine) CheckQuery(sourceIP, domain, queryType string) []*ThreatAlert {
|
|
var alerts []*ThreatAlert
|
|
|
|
logger.Debug("威胁检测引擎检查域名", "domain", domain, "sourceIP", sourceIP)
|
|
|
|
// 检查威胁域名数据库
|
|
if alert := e.checkThreatDatabase(sourceIP, domain); alert != nil {
|
|
alerts = append(alerts, alert)
|
|
}
|
|
|
|
// 更新客户端查询统计
|
|
e.updateClientStats(sourceIP, queryType)
|
|
|
|
return alerts
|
|
}
|
|
|
|
// checkThreatDatabase 检查威胁域名数据库
|
|
func (e *ThreatEngine) checkThreatDatabase(sourceIP, domain string) *ThreatAlert {
|
|
if e.dbManager == nil {
|
|
return nil
|
|
}
|
|
|
|
// 获取域名的威胁信息
|
|
threatInfo := e.dbManager.GetThreatInfo(domain)
|
|
if threatInfo == nil {
|
|
return nil
|
|
}
|
|
|
|
logger.Info("检测到威胁域名", "domain", domain, "type", threatInfo.Type, "name", threatInfo.Name, "riskLevel", threatInfo.RiskLevel)
|
|
|
|
// 根据风险等级确定告警级别(数据库:1=高,2=中,3=低)
|
|
var alertLevel string
|
|
switch threatInfo.RiskLevel {
|
|
case 1:
|
|
alertLevel = AlertLevelHigh // 高风险
|
|
case 2:
|
|
alertLevel = AlertLevelMedium // 中风险
|
|
case 3:
|
|
alertLevel = AlertLevelLow // 低风险
|
|
default:
|
|
alertLevel = AlertLevelLow
|
|
}
|
|
|
|
return &ThreatAlert{
|
|
ID: generateAlertID(),
|
|
Timestamp: time.Now(),
|
|
Level: alertLevel,
|
|
Type: threatInfo.Type, // 使用数据库中的 type 列(如:钓鱼网站、仿冒网站)
|
|
Description: threatInfo.Name, // 使用数据库中的 name 列(如:Silver fox 团伙)
|
|
Details: "威胁类型:" + threatInfo.Type + ", 威胁名称:" + threatInfo.Name + ", 风险等级:" + string(rune('0'+threatInfo.RiskLevel)),
|
|
SourceIP: sourceIP,
|
|
Domain: domain,
|
|
QueryType: "",
|
|
Resolved: false,
|
|
}
|
|
}
|
|
|
|
// updateClientStats 更新客户端查询统计
|
|
func (e *ThreatEngine) updateClientStats(sourceIP, queryType string) {
|
|
e.mutex.Lock()
|
|
defer e.mutex.Unlock()
|
|
|
|
stats, exists := e.clientStats[sourceIP]
|
|
if !exists {
|
|
stats = &ClientQueryStats{
|
|
QueryCount: 1,
|
|
NXDomainCount: 0,
|
|
LastQueryTime: time.Now(),
|
|
QueryTypes: make(map[string]int),
|
|
}
|
|
e.clientStats[sourceIP] = stats
|
|
} else {
|
|
stats.QueryCount++
|
|
stats.LastQueryTime = time.Now()
|
|
}
|
|
|
|
// 更新查询类型统计
|
|
stats.QueryTypes[queryType]++
|
|
}
|
|
|
|
// generateAlertID 生成告警ID
|
|
func generateAlertID() string {
|
|
return time.Now().Format("20060102150405") + "-" + randomString(8)
|
|
}
|
|
|
|
// randomString returns a string of length characters drawn uniformly from
// [a-zA-Z0-9]. A non-positive length yields "".
//
// It uses the math/rand top-level functions, which are safe for concurrent
// use and (since Go 1.20) seeded randomly at program start. The output is
// suitable for identifiers such as alert IDs, not for secrets. Deriving
// characters from time.Now().UnixNano() with tiny sleeps is avoided: that
// produced correlated, collision-prone values and slept once per character.
func randomString(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

	if length <= 0 {
		return ""
	}

	result := make([]byte, length)
	for i := range result {
		result[i] = charset[rand.Intn(len(charset))]
	}
	return string(result)
}
|