增加威胁域名审计

This commit is contained in:
Alex Yang
2026-04-03 10:04:07 +08:00
parent 170cdb3537
commit f8e222aaf6
41 changed files with 81016 additions and 4672993 deletions
+111 -302
View File
@@ -131,18 +131,9 @@ type Server struct {
monthlyStats map[string]int64 // 按月统计屏蔽数量
// 新日志系统
logManager *log.LogManager // 日志管理器
// 旧日志系统(保留以兼容,但不再使用)
queryLogsMutex sync.RWMutex
queryLogs []QueryLog // 查询日志列表(已废弃)
maxQueryLogs int // 最大保存日志数量(已废弃)
logChannel chan QueryLog // 日志处理通道(已废弃)
saveTicker *time.Ticker // 用于定时保存数据(已废弃)
startTime time.Time // 服务器启动时间
saveDone chan struct{} // 用于通知保存协程停止(已废弃)
stopped bool // 服务器是否已经停止
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
logManager *log.LogManager // 日志管理器
archiveManager *log.ArchiveManager // 归档管理器
archiveQueryEngine *log.ArchiveQueryEngine // 归档查询引擎
// DNS查询缓存
DnsCache *DNSCache // DNS响应缓存
@@ -210,7 +201,6 @@ func NewServer(globalConfig *config.Config, shieldManager *shield.ShieldManager,
},
ctx: ctx,
cancel: cancel,
startTime: time.Now(), // 记录服务器启动时间
stats: &Stats{
Queries: 0,
Blocked: 0,
@@ -232,11 +222,6 @@ func NewServer(globalConfig *config.Config, shieldManager *shield.ShieldManager,
hourlyStats: make(map[string]int64),
dailyStats: make(map[string]int64),
monthlyStats: make(map[string]int64),
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片,容量1000
maxQueryLogs: 10000, // 最大保存10000条日志
logChannel: make(chan QueryLog, 100000), // 日志处理通道,缓冲区大小 100000,避免日志被丢弃
saveDone: make(chan struct{}),
stopped: false, // 初始化为未停止状态
// DNS 查询缓存初始化
DnsCache: NewDNSCache(cacheTTL, globalConfig.DNS.CacheMode, globalConfig.DNS.CacheSize, globalConfig.DNS.CacheFilePath, saveInterval, maxCacheTTL, minCacheTTL),
@@ -267,6 +252,49 @@ func NewServer(globalConfig *config.Config, shieldManager *shield.ShieldManager,
logger.Info("新日志系统初始化成功", "ringBufferSize", log.DefaultConfig().RingBufferSize, "databasePath", log.DefaultConfig().DatabasePath)
}
// 初始化归档管理器
if globalConfig.QueryLog.ArchiveEnabled {
// 转换为 log.QueryLogConfig
logConfig := &log.QueryLogConfig{
Enabled: globalConfig.QueryLog.Enabled,
RingBufferSize: globalConfig.QueryLog.RingBufferSize,
DatabasePath: globalConfig.QueryLog.DatabasePath,
MaxDatabaseSizeMB: globalConfig.QueryLog.MaxDatabaseSizeMB,
EnableWAL: globalConfig.QueryLog.EnableWAL,
ArchiveEnabled: globalConfig.QueryLog.ArchiveEnabled,
ArchiveDir: globalConfig.QueryLog.ArchiveDir,
ArchivePrefix: globalConfig.QueryLog.ArchivePrefix,
CompressionLevel: globalConfig.QueryLog.CompressionLevel,
RetentionDays: globalConfig.QueryLog.RetentionDays,
RetentionMonths: globalConfig.QueryLog.RetentionMonths,
QueryTimeout: globalConfig.QueryLog.QueryTimeout,
EnableCache: globalConfig.QueryLog.EnableCache,
CacheTTL: globalConfig.QueryLog.CacheTTL,
}
archiveManager, err := log.NewArchiveManager(logConfig, globalConfig.QueryLog.DatabasePath)
if err != nil {
logger.Error("初始化归档管理器失败", "error", err)
} else {
server.archiveManager = archiveManager
logger.Info("归档管理器初始化成功", "archiveDir", globalConfig.QueryLog.ArchiveDir)
}
// 初始化归档查询引擎
if logManager != nil {
sqliteStore := logManager.GetSQLiteStore()
if sqliteStore != nil {
archiveQueryEngine, err := log.NewArchiveQueryEngine(sqliteStore, archiveManager, logConfig)
if err != nil {
logger.Error("初始化归档查询引擎失败", "error", err)
} else {
server.archiveQueryEngine = archiveQueryEngine
logger.Info("归档查询引擎初始化成功")
}
}
}
}
// 加载已保存的统计数据
server.loadStatsData()
@@ -291,6 +319,11 @@ func (s *Server) initThreatDetection() {
// 加载威胁域名数据库
s.dbManager.LoadDatabase()
// 启动文件监听,自动检测数据库变更
if err := s.dbManager.StartWatching(); err != nil {
logger.Warn("启动威胁域名数据库监听失败", "error", err)
}
// 创建威胁检测引擎
s.threatEngine = threat.NewThreatEngine(threatConfig, s.alertManager, s.dbManager)
}
@@ -303,15 +336,9 @@ func (s *Server) Start() error {
s.cancel = cancel
// 重新初始化saveDone通道
s.saveDone = make(chan struct{})
// 重置stopped标志
s.stoppedMutex.Lock()
s.stopped = false
s.stoppedMutex.Unlock()
// 更新服务器启动时间
s.startTime = time.Now()
// 初始化威胁检测相关组件
s.initThreatDetection()
@@ -341,7 +368,7 @@ func (s *Server) Start() error {
// 启动日志处理协程(已移除,新日志系统使用 SQLite 存储)
// go s.processLogs()
// 启动统计数据定期重置功能(每24小时)
// 启动统计数据定期重置功能(每 24 小时)
go func() {
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
@@ -355,6 +382,32 @@ func (s *Server) Start() error {
}
}()
// 启动归档监控和清理任务
if s.archiveManager != nil {
// 启动归档监控
s.archiveManager.StartWatching()
// 启动定期清理任务(每天执行)
go func() {
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if deleted, err := s.archiveManager.CleanupOldArchives(); err != nil {
logger.Error("清理归档失败", "error", err)
} else if deleted > 0 {
logger.Info("清理归档完成", "deleted", deleted)
}
case <-s.ctx.Done():
return
}
}
}()
logger.Info("归档监控和清理任务已启动")
}
// 启动UDP服务
go func() {
logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port))
@@ -402,17 +455,14 @@ func (s *Server) resetStats() {
// Stop 停止DNS服务器
func (s *Server) Stop() {
// 检查服务器是否已经停止
s.stoppedMutex.Lock()
if s.stopped {
s.stoppedMutex.Unlock()
return // 服务器已经停止,直接返回
}
// 标记服务器为已停止状态
s.stopped = true
s.stoppedMutex.Unlock()
// 停止威胁域名数据库文件监听
if s.dbManager != nil {
s.dbManager.StopWatching()
}
// 发送停止信号给保存协程
close(s.saveDone)
// 最后保存一次数据
s.saveStatsData()
@@ -2527,14 +2577,9 @@ func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime in
// 同时使用旧日志系统记录(兼容性)
// 发送到日志处理通道(阻塞式,确保日志不会丢失)
s.logChannel <- queryLog
logger.Debug("日志发送到通道", "domain", queryLog.Domain, "result", queryLog.Result)
}
// GetStartTime 获取服务器启动时间
func (s *Server) GetStartTime() time.Time {
return s.startTime
}
// GetStats 获取DNS服务器统计信息
func (s *Server) GetStats() *Stats {
@@ -2637,178 +2682,24 @@ func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resul
}
return result
}
// 如果新系统查询失败或没有数据,降级到旧系统
logger.Debug("新日志系统查询失败或无数据,降级到旧系统", "error", err, "logs", len(logs))
}
// 使用旧日志系统(兼容模式)
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
// 确保偏移量和限制值合理
if offset < 0 {
offset = 0
}
if limit <= 0 {
limit = 100 // 默认返回100条日志
}
// 设置合理的上限,防止请求过多数据
if limit > 1000 {
limit = 1000
}
// 预分配切片容量,减少内存分配
var filteredLogs []QueryLog
capacity := len(s.queryLogs)
if capacity > 10000 {
capacity = 10000 // 限制最大容量,避免内存使用过高
}
filteredLogs = make([]QueryLog, 0, capacity)
// 先过滤日志
for _, log := range s.queryLogs {
// 应用结果过滤
if resultFilter != "" && log.Result != resultFilter {
continue
}
// 应用解析类型过滤
if queryType != "" && log.QueryType != queryType {
continue
}
// 应用搜索过滤
if searchTerm != "" {
// 搜索域名或客户端IP,使用strings.Contains的优化版本
if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) {
continue
}
}
filteredLogs = append(filteredLogs, log)
}
// 排序日志
if sortField != "" {
// 使用更高效的排序方式,避免反射操作
switch sortField {
case "time":
if sortDirection == "asc" {
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].Timestamp.Before(filteredLogs[j].Timestamp)
})
} else {
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].Timestamp.After(filteredLogs[j].Timestamp)
})
}
case "clientIp":
if sortDirection == "asc" {
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].ClientIP < filteredLogs[j].ClientIP
})
} else {
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].ClientIP > filteredLogs[j].ClientIP
})
}
case "domain":
if sortDirection == "asc" {
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].Domain < filteredLogs[j].Domain
})
} else {
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].Domain > filteredLogs[j].Domain
})
}
case "responseTime":
if sortDirection == "asc" {
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].ResponseTime < filteredLogs[j].ResponseTime
})
} else {
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].ResponseTime > filteredLogs[j].ResponseTime
})
}
case "blockRule":
if sortDirection == "asc" {
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].BlockRule < filteredLogs[j].BlockRule
})
} else {
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].BlockRule > filteredLogs[j].BlockRule
})
}
default:
// 默认按时间降序排序
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].Timestamp.After(filteredLogs[j].Timestamp)
})
}
} else {
// 默认按时间降序排序
sort.Slice(filteredLogs, func(i, j int) bool {
return filteredLogs[i].Timestamp.After(filteredLogs[j].Timestamp)
})
}
// 计算返回范围
start := offset
end := offset + limit
if end > len(filteredLogs) {
end = len(filteredLogs)
}
if start >= len(filteredLogs) {
return []QueryLog{} // 没有数据,返回空切片
}
// 直接返回子切片,避免不必要的内存分配
return filteredLogs[start:end]
}
// compareValues 比较两个值
func compareValues(a, b interface{}) int {
switch v1 := a.(type) {
case time.Time:
v2 := b.(time.Time)
if v1.Before(v2) {
return -1
}
if v1.After(v2) {
return 1
}
return 0
case string:
v2 := b.(string)
if v1 < v2 {
return -1
}
if v1 > v2 {
return 1
}
return 0
case int64:
v2 := b.(int64)
if v1 < v2 {
return -1
}
if v1 > v2 {
return 1
}
return 0
default:
return 0
// 如果新系统查询失败或没有数据,返回空列表
logger.Debug("新日志系统查询失败或无数据", "error", err, "logs", len(logs))
}
// 返回空列表
return []QueryLog{}
}
// GetQueryLogsCount 获取查询日志总数
func (s *Server) GetQueryLogsCount() int {
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
return len(s.queryLogs)
// 使用新日志系统获取总数
if s.logManager != nil {
stats, err := s.logManager.GetStats(log.TimeRange{})
if err == nil {
return int(stats.TotalQueries)
}
}
return 0
}
// GetQueryLogsCountWithFilter 获取带过滤条件的查询日志总数
@@ -2822,9 +2713,9 @@ func (s *Server) GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType
}
page := log.PageParams{
Limit: 1, // 只需要总数,不需要实际数据
Limit: 1, // 只需要总数,获取 1 条数据即可
Offset: 0,
SortField: "",
SortField: "timestamp",
SortDirection: "desc",
}
@@ -2832,41 +2723,11 @@ func (s *Server) GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType
if err == nil {
return int(total)
}
// 如果新系统查询失败,降级到旧系统
logger.Debug("新日志系统查询失败,降级到旧系统", "error", err)
// 如果新系统查询失败,返回 0
logger.Debug("新日志系统查询失败", "error", err)
}
// 使用旧日志系统(兼容模式)
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
count := 0
for _, log := range s.queryLogs {
// 应用结果过滤
if resultFilter != "" && log.Result != resultFilter {
continue
}
// 应用解析类型过滤
if queryType != "" && log.QueryType != queryType {
continue
}
// 应用搜索过滤
if searchTerm != "" {
// 搜索域名或客户端IP
if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) {
continue
}
}
count++
}
return count
return 0
}
// GetQueryStats 获取查询统计信息
func (s *Server) GetQueryStats() map[string]interface{} {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
@@ -3237,54 +3098,6 @@ func (s *Server) loadStatsData() {
s.clientStatsMutex.Unlock()
logger.Info("统计数据加载成功")
// 加载查询日志
s.loadQueryLogs()
}
// loadQueryLogs 从文件加载查询日志
func (s *Server) loadQueryLogs() {
// 获取绝对路径
statsFilePath, err := filepath.Abs("data/stats.json")
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
return
}
// 构建查询日志文件路径
queryLogPath := filepath.Join(filepath.Dir(statsFilePath), "querylog.json")
// 检查文件是否存在
if _, err := os.Stat(queryLogPath); os.IsNotExist(err) {
logger.Info("查询日志文件不存在,将使用空列表", "file", queryLogPath)
return
}
// 读取文件内容
data, err := ioutil.ReadFile(queryLogPath)
if err != nil {
logger.Error("读取查询日志文件失败", "error", err)
return
}
// 解析数据
var logs []QueryLog
err = json.Unmarshal(data, &logs)
if err != nil {
logger.Error("解析查询日志失败", "error", err)
return
}
// 更新查询日志
s.queryLogsMutex.Lock()
s.queryLogs = logs
// 确保日志数量不超过限制
if len(s.queryLogs) > s.maxQueryLogs {
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
}
s.queryLogsMutex.Unlock()
logger.Info("查询日志加载成功", "count", len(logs))
}
// processLogs 异步处理日志记录
@@ -3450,18 +3263,14 @@ func (s *Server) startAutoSave() {
}
// 初始化定时器
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
defer s.saveTicker.Stop()
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", "data/stats.json")
// 定期保存数据
for {
select {
case <-s.saveTicker.C:
s.saveStatsData()
case <-s.saveDone:
return
}
}
}
// GetArchiveQueryEngine 获取归档查询引擎
func (s *Server) GetArchiveQueryEngine() *log.ArchiveQueryEngine {
return s.archiveQueryEngine
}
// GetArchiveManager 获取归档管理器
func (s *Server) GetArchiveManager() *log.ArchiveManager {
return s.archiveManager
}