增加威胁域名审计
This commit is contained in:
+111
-302
@@ -131,18 +131,9 @@ type Server struct {
|
||||
monthlyStats map[string]int64 // 按月统计屏蔽数量
|
||||
|
||||
// 新日志系统
|
||||
logManager *log.LogManager // 日志管理器
|
||||
|
||||
// 旧日志系统(保留以兼容,但不再使用)
|
||||
queryLogsMutex sync.RWMutex
|
||||
queryLogs []QueryLog // 查询日志列表(已废弃)
|
||||
maxQueryLogs int // 最大保存日志数量(已废弃)
|
||||
logChannel chan QueryLog // 日志处理通道(已废弃)
|
||||
saveTicker *time.Ticker // 用于定时保存数据(已废弃)
|
||||
startTime time.Time // 服务器启动时间
|
||||
saveDone chan struct{} // 用于通知保存协程停止(已废弃)
|
||||
stopped bool // 服务器是否已经停止
|
||||
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
|
||||
logManager *log.LogManager // 日志管理器
|
||||
archiveManager *log.ArchiveManager // 归档管理器
|
||||
archiveQueryEngine *log.ArchiveQueryEngine // 归档查询引擎
|
||||
|
||||
// DNS查询缓存
|
||||
DnsCache *DNSCache // DNS响应缓存
|
||||
@@ -210,7 +201,6 @@ func NewServer(globalConfig *config.Config, shieldManager *shield.ShieldManager,
|
||||
},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
startTime: time.Now(), // 记录服务器启动时间
|
||||
stats: &Stats{
|
||||
Queries: 0,
|
||||
Blocked: 0,
|
||||
@@ -232,11 +222,6 @@ func NewServer(globalConfig *config.Config, shieldManager *shield.ShieldManager,
|
||||
hourlyStats: make(map[string]int64),
|
||||
dailyStats: make(map[string]int64),
|
||||
monthlyStats: make(map[string]int64),
|
||||
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片,容量1000
|
||||
maxQueryLogs: 10000, // 最大保存10000条日志
|
||||
logChannel: make(chan QueryLog, 100000), // 日志处理通道,缓冲区大小 100000,避免日志被丢弃
|
||||
saveDone: make(chan struct{}),
|
||||
stopped: false, // 初始化为未停止状态
|
||||
|
||||
// DNS 查询缓存初始化
|
||||
DnsCache: NewDNSCache(cacheTTL, globalConfig.DNS.CacheMode, globalConfig.DNS.CacheSize, globalConfig.DNS.CacheFilePath, saveInterval, maxCacheTTL, minCacheTTL),
|
||||
@@ -267,6 +252,49 @@ func NewServer(globalConfig *config.Config, shieldManager *shield.ShieldManager,
|
||||
logger.Info("新日志系统初始化成功", "ringBufferSize", log.DefaultConfig().RingBufferSize, "databasePath", log.DefaultConfig().DatabasePath)
|
||||
}
|
||||
|
||||
// 初始化归档管理器
|
||||
if globalConfig.QueryLog.ArchiveEnabled {
|
||||
// 转换为 log.QueryLogConfig
|
||||
logConfig := &log.QueryLogConfig{
|
||||
Enabled: globalConfig.QueryLog.Enabled,
|
||||
RingBufferSize: globalConfig.QueryLog.RingBufferSize,
|
||||
DatabasePath: globalConfig.QueryLog.DatabasePath,
|
||||
MaxDatabaseSizeMB: globalConfig.QueryLog.MaxDatabaseSizeMB,
|
||||
EnableWAL: globalConfig.QueryLog.EnableWAL,
|
||||
ArchiveEnabled: globalConfig.QueryLog.ArchiveEnabled,
|
||||
ArchiveDir: globalConfig.QueryLog.ArchiveDir,
|
||||
ArchivePrefix: globalConfig.QueryLog.ArchivePrefix,
|
||||
CompressionLevel: globalConfig.QueryLog.CompressionLevel,
|
||||
RetentionDays: globalConfig.QueryLog.RetentionDays,
|
||||
RetentionMonths: globalConfig.QueryLog.RetentionMonths,
|
||||
QueryTimeout: globalConfig.QueryLog.QueryTimeout,
|
||||
EnableCache: globalConfig.QueryLog.EnableCache,
|
||||
CacheTTL: globalConfig.QueryLog.CacheTTL,
|
||||
}
|
||||
|
||||
archiveManager, err := log.NewArchiveManager(logConfig, globalConfig.QueryLog.DatabasePath)
|
||||
if err != nil {
|
||||
logger.Error("初始化归档管理器失败", "error", err)
|
||||
} else {
|
||||
server.archiveManager = archiveManager
|
||||
logger.Info("归档管理器初始化成功", "archiveDir", globalConfig.QueryLog.ArchiveDir)
|
||||
}
|
||||
|
||||
// 初始化归档查询引擎
|
||||
if logManager != nil {
|
||||
sqliteStore := logManager.GetSQLiteStore()
|
||||
if sqliteStore != nil {
|
||||
archiveQueryEngine, err := log.NewArchiveQueryEngine(sqliteStore, archiveManager, logConfig)
|
||||
if err != nil {
|
||||
logger.Error("初始化归档查询引擎失败", "error", err)
|
||||
} else {
|
||||
server.archiveQueryEngine = archiveQueryEngine
|
||||
logger.Info("归档查询引擎初始化成功")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 加载已保存的统计数据
|
||||
server.loadStatsData()
|
||||
|
||||
@@ -291,6 +319,11 @@ func (s *Server) initThreatDetection() {
|
||||
// 加载威胁域名数据库
|
||||
s.dbManager.LoadDatabase()
|
||||
|
||||
// 启动文件监听,自动检测数据库变更
|
||||
if err := s.dbManager.StartWatching(); err != nil {
|
||||
logger.Warn("启动威胁域名数据库监听失败", "error", err)
|
||||
}
|
||||
|
||||
// 创建威胁检测引擎
|
||||
s.threatEngine = threat.NewThreatEngine(threatConfig, s.alertManager, s.dbManager)
|
||||
}
|
||||
@@ -303,15 +336,9 @@ func (s *Server) Start() error {
|
||||
s.cancel = cancel
|
||||
|
||||
// 重新初始化saveDone通道
|
||||
s.saveDone = make(chan struct{})
|
||||
|
||||
// 重置stopped标志
|
||||
s.stoppedMutex.Lock()
|
||||
s.stopped = false
|
||||
s.stoppedMutex.Unlock()
|
||||
|
||||
// 更新服务器启动时间
|
||||
s.startTime = time.Now()
|
||||
|
||||
// 初始化威胁检测相关组件
|
||||
s.initThreatDetection()
|
||||
@@ -341,7 +368,7 @@ func (s *Server) Start() error {
|
||||
// 启动日志处理协程(已移除,新日志系统使用 SQLite 存储)
|
||||
// go s.processLogs()
|
||||
|
||||
// 启动统计数据定期重置功能(每24小时)
|
||||
// 启动统计数据定期重置功能(每 24 小时)
|
||||
go func() {
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
@@ -355,6 +382,32 @@ func (s *Server) Start() error {
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动归档监控和清理任务
|
||||
if s.archiveManager != nil {
|
||||
// 启动归档监控
|
||||
s.archiveManager.StartWatching()
|
||||
|
||||
// 启动定期清理任务(每天执行)
|
||||
go func() {
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if deleted, err := s.archiveManager.CleanupOldArchives(); err != nil {
|
||||
logger.Error("清理归档失败", "error", err)
|
||||
} else if deleted > 0 {
|
||||
logger.Info("清理归档完成", "deleted", deleted)
|
||||
}
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("归档监控和清理任务已启动")
|
||||
}
|
||||
|
||||
// 启动UDP服务
|
||||
go func() {
|
||||
logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port))
|
||||
@@ -402,17 +455,14 @@ func (s *Server) resetStats() {
|
||||
// Stop 停止DNS服务器
|
||||
func (s *Server) Stop() {
|
||||
// 检查服务器是否已经停止
|
||||
s.stoppedMutex.Lock()
|
||||
if s.stopped {
|
||||
s.stoppedMutex.Unlock()
|
||||
return // 服务器已经停止,直接返回
|
||||
}
|
||||
// 标记服务器为已停止状态
|
||||
s.stopped = true
|
||||
s.stoppedMutex.Unlock()
|
||||
|
||||
// 停止威胁域名数据库文件监听
|
||||
if s.dbManager != nil {
|
||||
s.dbManager.StopWatching()
|
||||
}
|
||||
|
||||
// 发送停止信号给保存协程
|
||||
close(s.saveDone)
|
||||
|
||||
// 最后保存一次数据
|
||||
s.saveStatsData()
|
||||
@@ -2527,14 +2577,9 @@ func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime in
|
||||
|
||||
// 同时使用旧日志系统记录(兼容性)
|
||||
// 发送到日志处理通道(阻塞式,确保日志不会丢失)
|
||||
s.logChannel <- queryLog
|
||||
logger.Debug("日志发送到通道", "domain", queryLog.Domain, "result", queryLog.Result)
|
||||
}
|
||||
|
||||
// GetStartTime 获取服务器启动时间
|
||||
func (s *Server) GetStartTime() time.Time {
|
||||
return s.startTime
|
||||
}
|
||||
|
||||
// GetStats 获取DNS服务器统计信息
|
||||
func (s *Server) GetStats() *Stats {
|
||||
@@ -2637,178 +2682,24 @@ func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resul
|
||||
}
|
||||
return result
|
||||
}
|
||||
// 如果新系统查询失败或没有数据,降级到旧系统
|
||||
logger.Debug("新日志系统查询失败或无数据,降级到旧系统", "error", err, "logs", len(logs))
|
||||
}
|
||||
|
||||
// 使用旧日志系统(兼容模式)
|
||||
s.queryLogsMutex.RLock()
|
||||
defer s.queryLogsMutex.RUnlock()
|
||||
|
||||
// 确保偏移量和限制值合理
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 100 // 默认返回100条日志
|
||||
}
|
||||
// 设置合理的上限,防止请求过多数据
|
||||
if limit > 1000 {
|
||||
limit = 1000
|
||||
}
|
||||
|
||||
// 预分配切片容量,减少内存分配
|
||||
var filteredLogs []QueryLog
|
||||
capacity := len(s.queryLogs)
|
||||
if capacity > 10000 {
|
||||
capacity = 10000 // 限制最大容量,避免内存使用过高
|
||||
}
|
||||
filteredLogs = make([]QueryLog, 0, capacity)
|
||||
|
||||
// 先过滤日志
|
||||
for _, log := range s.queryLogs {
|
||||
// 应用结果过滤
|
||||
if resultFilter != "" && log.Result != resultFilter {
|
||||
continue
|
||||
}
|
||||
|
||||
// 应用解析类型过滤
|
||||
if queryType != "" && log.QueryType != queryType {
|
||||
continue
|
||||
}
|
||||
|
||||
// 应用搜索过滤
|
||||
if searchTerm != "" {
|
||||
// 搜索域名或客户端IP,使用strings.Contains的优化版本
|
||||
if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
filteredLogs = append(filteredLogs, log)
|
||||
}
|
||||
|
||||
// 排序日志
|
||||
if sortField != "" {
|
||||
// 使用更高效的排序方式,避免反射操作
|
||||
switch sortField {
|
||||
case "time":
|
||||
if sortDirection == "asc" {
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].Timestamp.Before(filteredLogs[j].Timestamp)
|
||||
})
|
||||
} else {
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].Timestamp.After(filteredLogs[j].Timestamp)
|
||||
})
|
||||
}
|
||||
case "clientIp":
|
||||
if sortDirection == "asc" {
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].ClientIP < filteredLogs[j].ClientIP
|
||||
})
|
||||
} else {
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].ClientIP > filteredLogs[j].ClientIP
|
||||
})
|
||||
}
|
||||
case "domain":
|
||||
if sortDirection == "asc" {
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].Domain < filteredLogs[j].Domain
|
||||
})
|
||||
} else {
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].Domain > filteredLogs[j].Domain
|
||||
})
|
||||
}
|
||||
case "responseTime":
|
||||
if sortDirection == "asc" {
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].ResponseTime < filteredLogs[j].ResponseTime
|
||||
})
|
||||
} else {
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].ResponseTime > filteredLogs[j].ResponseTime
|
||||
})
|
||||
}
|
||||
case "blockRule":
|
||||
if sortDirection == "asc" {
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].BlockRule < filteredLogs[j].BlockRule
|
||||
})
|
||||
} else {
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].BlockRule > filteredLogs[j].BlockRule
|
||||
})
|
||||
}
|
||||
default:
|
||||
// 默认按时间降序排序
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].Timestamp.After(filteredLogs[j].Timestamp)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// 默认按时间降序排序
|
||||
sort.Slice(filteredLogs, func(i, j int) bool {
|
||||
return filteredLogs[i].Timestamp.After(filteredLogs[j].Timestamp)
|
||||
})
|
||||
}
|
||||
|
||||
// 计算返回范围
|
||||
start := offset
|
||||
end := offset + limit
|
||||
if end > len(filteredLogs) {
|
||||
end = len(filteredLogs)
|
||||
}
|
||||
if start >= len(filteredLogs) {
|
||||
return []QueryLog{} // 没有数据,返回空切片
|
||||
}
|
||||
|
||||
// 直接返回子切片,避免不必要的内存分配
|
||||
return filteredLogs[start:end]
|
||||
}
|
||||
|
||||
// compareValues 比较两个值
|
||||
func compareValues(a, b interface{}) int {
|
||||
switch v1 := a.(type) {
|
||||
case time.Time:
|
||||
v2 := b.(time.Time)
|
||||
if v1.Before(v2) {
|
||||
return -1
|
||||
}
|
||||
if v1.After(v2) {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
case string:
|
||||
v2 := b.(string)
|
||||
if v1 < v2 {
|
||||
return -1
|
||||
}
|
||||
if v1 > v2 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
case int64:
|
||||
v2 := b.(int64)
|
||||
if v1 < v2 {
|
||||
return -1
|
||||
}
|
||||
if v1 > v2 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
default:
|
||||
return 0
|
||||
// 如果新系统查询失败或没有数据,返回空列表
|
||||
logger.Debug("新日志系统查询失败或无数据", "error", err, "logs", len(logs))
|
||||
}
|
||||
|
||||
// 返回空列表
|
||||
return []QueryLog{}
|
||||
}
|
||||
|
||||
// GetQueryLogsCount 获取查询日志总数
|
||||
func (s *Server) GetQueryLogsCount() int {
|
||||
s.queryLogsMutex.RLock()
|
||||
defer s.queryLogsMutex.RUnlock()
|
||||
return len(s.queryLogs)
|
||||
// 使用新日志系统获取总数
|
||||
if s.logManager != nil {
|
||||
stats, err := s.logManager.GetStats(log.TimeRange{})
|
||||
if err == nil {
|
||||
return int(stats.TotalQueries)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetQueryLogsCountWithFilter 获取带过滤条件的查询日志总数
|
||||
@@ -2822,9 +2713,9 @@ func (s *Server) GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType
|
||||
}
|
||||
|
||||
page := log.PageParams{
|
||||
Limit: 1, // 只需要总数,不需要实际数据
|
||||
Limit: 1, // 只需要总数,获取 1 条数据即可
|
||||
Offset: 0,
|
||||
SortField: "",
|
||||
SortField: "timestamp",
|
||||
SortDirection: "desc",
|
||||
}
|
||||
|
||||
@@ -2832,41 +2723,11 @@ func (s *Server) GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType
|
||||
if err == nil {
|
||||
return int(total)
|
||||
}
|
||||
// 如果新系统查询失败,降级到旧系统
|
||||
logger.Debug("新日志系统查询失败,降级到旧系统", "error", err)
|
||||
// 如果新系统查询失败,返回 0
|
||||
logger.Debug("新日志系统查询失败", "error", err)
|
||||
}
|
||||
|
||||
// 使用旧日志系统(兼容模式)
|
||||
s.queryLogsMutex.RLock()
|
||||
defer s.queryLogsMutex.RUnlock()
|
||||
|
||||
count := 0
|
||||
for _, log := range s.queryLogs {
|
||||
// 应用结果过滤
|
||||
if resultFilter != "" && log.Result != resultFilter {
|
||||
continue
|
||||
}
|
||||
|
||||
// 应用解析类型过滤
|
||||
if queryType != "" && log.QueryType != queryType {
|
||||
continue
|
||||
}
|
||||
|
||||
// 应用搜索过滤
|
||||
if searchTerm != "" {
|
||||
// 搜索域名或客户端IP
|
||||
if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
count++
|
||||
}
|
||||
|
||||
return count
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetQueryStats 获取查询统计信息
|
||||
func (s *Server) GetQueryStats() map[string]interface{} {
|
||||
s.statsMutex.Lock()
|
||||
defer s.statsMutex.Unlock()
|
||||
@@ -3237,54 +3098,6 @@ func (s *Server) loadStatsData() {
|
||||
s.clientStatsMutex.Unlock()
|
||||
|
||||
logger.Info("统计数据加载成功")
|
||||
|
||||
// 加载查询日志
|
||||
s.loadQueryLogs()
|
||||
}
|
||||
|
||||
// loadQueryLogs 从文件加载查询日志
|
||||
func (s *Server) loadQueryLogs() {
|
||||
// 获取绝对路径
|
||||
statsFilePath, err := filepath.Abs("data/stats.json")
|
||||
if err != nil {
|
||||
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建查询日志文件路径
|
||||
queryLogPath := filepath.Join(filepath.Dir(statsFilePath), "querylog.json")
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(queryLogPath); os.IsNotExist(err) {
|
||||
logger.Info("查询日志文件不存在,将使用空列表", "file", queryLogPath)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
data, err := ioutil.ReadFile(queryLogPath)
|
||||
if err != nil {
|
||||
logger.Error("读取查询日志文件失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析数据
|
||||
var logs []QueryLog
|
||||
err = json.Unmarshal(data, &logs)
|
||||
if err != nil {
|
||||
logger.Error("解析查询日志失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新查询日志
|
||||
s.queryLogsMutex.Lock()
|
||||
s.queryLogs = logs
|
||||
// 确保日志数量不超过限制
|
||||
if len(s.queryLogs) > s.maxQueryLogs {
|
||||
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
|
||||
}
|
||||
s.queryLogsMutex.Unlock()
|
||||
|
||||
logger.Info("查询日志加载成功", "count", len(logs))
|
||||
}
|
||||
|
||||
// processLogs 异步处理日志记录
|
||||
@@ -3450,18 +3263,14 @@ func (s *Server) startAutoSave() {
|
||||
}
|
||||
|
||||
// 初始化定时器
|
||||
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
|
||||
defer s.saveTicker.Stop()
|
||||
|
||||
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", "data/stats.json")
|
||||
|
||||
// 定期保存数据
|
||||
for {
|
||||
select {
|
||||
case <-s.saveTicker.C:
|
||||
s.saveStatsData()
|
||||
case <-s.saveDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetArchiveQueryEngine 获取归档查询引擎
|
||||
func (s *Server) GetArchiveQueryEngine() *log.ArchiveQueryEngine {
|
||||
return s.archiveQueryEngine
|
||||
}
|
||||
|
||||
// GetArchiveManager 获取归档管理器
|
||||
func (s *Server) GetArchiveManager() *log.ArchiveManager {
|
||||
return s.archiveManager
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user