This commit is contained in:
Alex Yang
2026-04-01 12:22:55 +08:00
parent 61789061ce
commit efebce3c39
46 changed files with 4797716 additions and 462145 deletions
+304 -107
View File
@@ -19,8 +19,10 @@ import (
"dns-server/config"
"dns-server/gfw"
"dns-server/log"
"dns-server/logger"
"dns-server/shield"
"dns-server/threat"
"github.com/miekg/dns"
)
@@ -102,6 +104,7 @@ type ServerStats struct {
// Server DNS服务器
type Server struct {
globalConfig *config.Config
config *config.DNSConfig
shieldConfig *config.ShieldConfig
shieldManager *shield.ShieldManager
@@ -126,15 +129,20 @@ type Server struct {
dailyStats map[string]int64 // 按天统计屏蔽数量
monthlyStatsMutex sync.RWMutex
monthlyStats map[string]int64 // 按月统计屏蔽数量
queryLogsMutex sync.RWMutex
queryLogs []QueryLog // 查询日志列表
maxQueryLogs int // 最大保存日志数量
logChannel chan QueryLog // 日志处理通道
saveTicker *time.Ticker // 用于定时保存数据
startTime time.Time // 服务器启动时间
saveDone chan struct{} // 用于通知保存协程停止
stopped bool // 服务器是否已经停止
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
// 新日志系统
logManager *log.LogManager // 日志管理器
// 旧日志系统(保留以兼容,但不再使用)
queryLogsMutex sync.RWMutex
queryLogs []QueryLog // 查询日志列表(已废弃)
maxQueryLogs int // 最大保存日志数量(已废弃)
logChannel chan QueryLog // 日志处理通道(已废弃)
saveTicker *time.Ticker // 用于定时保存数据(已废弃)
startTime time.Time // 服务器启动时间
saveDone chan struct{} // 用于通知保存协程停止(已废弃)
stopped bool // 服务器是否已经停止
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
// DNS查询缓存
DnsCache *DNSCache // DNS响应缓存
@@ -152,6 +160,11 @@ type Server struct {
// DNS客户端实例池,用于并行查询
clientPool sync.Pool // 存储*dns.Client实例
// 威胁检测相关
threatEngine *threat.ThreatEngine
alertManager *threat.AlertManager
dbManager *threat.ThreatDatabaseManager
}
// Stats DNS服务器统计信息
@@ -173,22 +186,23 @@ type Stats struct {
}
// NewServer 创建DNS服务器实例
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager, gfwConfig *config.GFWListConfig, gfwManager *gfw.GFWListManager) *Server {
func NewServer(globalConfig *config.Config, shieldManager *shield.ShieldManager, gfwManager *gfw.GFWListManager) *Server {
ctx, cancel := context.WithCancel(context.Background())
// 从配置中读取DNS缓存TTL值(分钟)
cacheTTL := time.Duration(config.CacheTTL) * time.Minute
cacheTTL := time.Duration(globalConfig.DNS.CacheTTL) * time.Minute
// 保存间隔(秒)
saveInterval := time.Duration(config.SaveInterval) * time.Second
saveInterval := time.Duration(globalConfig.DNS.SaveInterval) * time.Second
// 最大和最小缓存TTL(分钟)
maxCacheTTL := time.Duration(config.MaxCacheTTL) * time.Minute
minCacheTTL := time.Duration(config.MinCacheTTL) * time.Minute
maxCacheTTL := time.Duration(globalConfig.DNS.MaxCacheTTL) * time.Minute
minCacheTTL := time.Duration(globalConfig.DNS.MinCacheTTL) * time.Minute
server := &Server{
config: config,
shieldConfig: shieldConfig,
globalConfig: globalConfig,
config: &globalConfig.DNS,
shieldConfig: &globalConfig.Shield,
shieldManager: shieldManager,
gfwConfig: gfwConfig,
gfwConfig: &globalConfig.GFWList,
gfwManager: gfwManager,
resolver: &dns.Client{
Net: "udp",
@@ -210,7 +224,7 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
DNSSECQueries: 0,
DNSSECSuccess: 0,
DNSSECFailed: 0,
DNSSECEnabled: config.EnableDNSSEC,
DNSSECEnabled: globalConfig.DNS.EnableDNSSEC,
},
blockedDomains: make(map[string]*BlockedDomain),
resolvedDomains: make(map[string]*BlockedDomain),
@@ -218,21 +232,21 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
hourlyStats: make(map[string]int64),
dailyStats: make(map[string]int64),
monthlyStats: make(map[string]int64),
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片,容量1000
maxQueryLogs: 10000, // 最大保存10000条日志
logChannel: make(chan QueryLog, 1000), // 日志处理通道,缓冲区大小1000
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片,容量1000
maxQueryLogs: 10000, // 最大保存10000条日志
logChannel: make(chan QueryLog, 100000), // 日志处理通道,缓冲区大小 100000,避免日志被丢弃
saveDone: make(chan struct{}),
stopped: false, // 初始化为未停止状态
// DNS查询缓存初始化
DnsCache: NewDNSCache(cacheTTL, config.CacheMode, config.CacheSize, config.CacheFilePath, saveInterval, maxCacheTTL, minCacheTTL),
// 初始化域名DNSSEC状态映射表
// DNS 查询缓存初始化
DnsCache: NewDNSCache(cacheTTL, globalConfig.DNS.CacheMode, globalConfig.DNS.CacheSize, globalConfig.DNS.CacheFilePath, saveInterval, maxCacheTTL, minCacheTTL),
// 初始化域名 DNSSEC 状态映射表
domainDNSSECStatus: make(map[string]bool),
// 初始化服务器状态跟踪
serverStats: make(map[string]*ServerStats),
// 初始化DNSSEC专用服务器映射
// 初始化 DNSSEC 专用服务器映射
dnssecServerMap: make(map[string]bool),
// 初始化DNS客户端实例池
// 初始化 DNS 客户端实例池
clientPool: sync.Pool{
New: func() interface{} {
return &dns.Client{
@@ -244,6 +258,15 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
},
}
// 初始化新日志系统
logManager, err := log.NewLogManager(log.DefaultConfig())
if err != nil {
logger.Error("初始化日志管理器失败", "error", err)
} else {
server.logManager = logManager
logger.Info("新日志系统初始化成功", "ringBufferSize", log.DefaultConfig().RingBufferSize, "databasePath", log.DefaultConfig().DatabasePath)
}
// 加载已保存的统计数据
server.loadStatsData()
@@ -251,6 +274,27 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
}
// initThreatDetection 初始化威胁检测相关组件
func (s *Server) initThreatDetection() {
// 从全局配置中获取威胁检测配置
threatConfig := &s.globalConfig.Threat
// 创建告警管理器
s.alertManager = threat.NewAlertManager(threatConfig)
// 加载已保存的告警
s.alertManager.LoadAlerts()
// 创建威胁域名数据库管理器
s.dbManager = threat.NewThreatDatabaseManager(threatConfig.ThreatDatabasePath)
// 加载威胁域名数据库
s.dbManager.LoadDatabase()
// 创建威胁检测引擎
s.threatEngine = threat.NewThreatEngine(threatConfig, s.alertManager, s.dbManager)
}
// Start 启动DNS服务器
func (s *Server) Start() error {
// 重新初始化上下文和取消函数
@@ -269,6 +313,9 @@ func (s *Server) Start() error {
// 更新服务器启动时间
s.startTime = time.Now()
// 初始化威胁检测相关组件
s.initThreatDetection()
s.server = &dns.Server{
Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port),
Net: "udp",
@@ -291,8 +338,8 @@ func (s *Server) Start() error {
// 更新DNSSEC专用服务器映射
s.updateDNSSECServerMap()
// 启动日志处理协程
go s.processLogs()
// 启动日志处理协程(已移除,新日志系统使用 SQLite 存储)
// go s.processLogs()
// 启动统计数据定期重置功能(每24小时)
go func() {
@@ -403,7 +450,10 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
return
}
// 5. 转发请求到上游服务器
// 5. 威胁检测
s.checkThreatDetection(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType)
// 6. 转发请求到上游服务器
s.handleUpstreamRequest(w, r, startTime, reqInfo)
}
@@ -531,8 +581,8 @@ func (s *Server) checkRequestConditions(w dns.ResponseWriter, r *dns.Msg, startT
func (s *Server) handleLocalRules(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) bool {
// 本地规则匹配的响应时间极短,使用固定值1ms
const localResponseTime int64 = 1
// 检查hosts文件是否有匹配
// 检查 hosts 文件是否有匹配
if ip, exists := s.shieldManager.GetHostsIP(reqInfo.domain); exists {
s.handleHostsResponse(w, r, ip)
// 使用固定的短响应时间
@@ -540,6 +590,14 @@ func (s *Server) handleLocalRules(w dns.ResponseWriter, r *dns.Msg, startTime ti
stats.TotalResponseTime += localResponseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 添加查询日志 - hosts 文件匹配
hostsAnswers := []DNSAnswer{{
Type: "A",
Value: ip,
TTL: 300,
}}
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, localResponseTime, "allowed", "", "", false, false, true, "hosts", "无", hostsAnswers, dns.RcodeSuccess)
return true
}
@@ -671,7 +729,7 @@ func (s *Server) handleCacheResponse(w dns.ResponseWriter, r *dns.Msg, startTime
// 缓存命中的响应时间应该是极短的,使用固定值1ms而非实际处理时间
const cacheResponseTime int64 = 1
// 缓存命中的响应视为正常解析
s.updateStats(func(stats *Stats) {
stats.Allowed++
@@ -713,7 +771,9 @@ func (s *Server) handleCacheResponse(w dns.ResponseWriter, r *dns.Msg, startTime
// handleUpstreamRequest 处理上游请求
func (s *Server) handleUpstreamRequest(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) {
// 缓存未命中,处理DNS请求
logger.Debug("开始处理上游请求", "domain", reqInfo.domain, "type", reqInfo.queryType)
// 缓存未命中,处理 DNS 请求
var response *dns.Msg
var rtt time.Duration
var dnsServer string
@@ -722,6 +782,8 @@ func (s *Server) handleUpstreamRequest(w dns.ResponseWriter, r *dns.Msg, startTi
// 直接查询原始域名
response, rtt, dnsServer, dnssecServer = s.forwardDNSRequestWithCache(r, reqInfo.domain)
logger.Debug("上游请求返回", "domain", reqInfo.domain, "response", response != nil, "rtt", rtt)
if response != nil {
// 如果客户端请求包含EDNS记录,确保响应也包含EDNS
if opt := r.IsEdns0(); opt != nil {
@@ -882,15 +944,20 @@ func (s *Server) handleUpstreamRequest(w dns.ResponseWriter, r *dns.Msg, startTi
if response != nil {
realRcode = response.Rcode
}
logger.Debug("准备添加查询日志", "domain", reqInfo.domain, "result", resultType, "responseCode", realRcode)
// 添加查询日志
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, responseTime, resultType, "", "", false, responseDNSSEC, true, dnsServer, dnssecServer, responseAnswers, realRcode)
logger.Debug("查询日志已添加", "domain", reqInfo.domain)
}
// handleHostsResponse 处理hosts文件匹配的响应
// handleHostsResponse 处理 hosts 文件匹配的响应
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
response := new(dns.Msg)
response.SetReply(r)
// 不再硬编码RecursionAvailable,使用默认值或上游返回的值
// 不再硬编码 RecursionAvailable,使用默认值或上游返回的值
if len(r.Question) > 0 {
q := r.Question[0]
@@ -916,7 +983,7 @@ func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string
}
w.WriteMsg(response)
// 本地hosts匹配响应时间极短,使用固定值1ms
// 本地 hosts 匹配响应时间极短,使用固定值 1ms
const localResponseTime int64 = 1
s.updateStats(func(stats *Stats) {
stats.Allowed++
@@ -2406,7 +2473,7 @@ func (s *Server) updateStats(update func(*Stats)) {
// addQueryLog 添加查询日志
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool, dnsServer, dnssecServer string, answers []DNSAnswer, responseCode int) {
// 创建日志记录
log := QueryLog{
queryLog := QueryLog{
Timestamp: time.Now(),
ClientIP: clientIP,
Domain: domain,
@@ -2424,14 +2491,44 @@ func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime in
ResponseCode: responseCode,
}
// 发送到日志处理通道(非阻塞
select {
case s.logChannel <- log:
// 日志发送成功
default:
// 通道已满,丢弃日志以避免阻塞请求处理
logger.Warn("日志通道已满,丢弃一条日志记录")
// 使用新日志系统记录(如果已初始化
if s.logManager != nil {
// 将 Answers 转换为 JSON 字符串
answersJSON := ""
if len(queryLog.Answers) > 0 {
data, err := json.Marshal(queryLog.Answers)
if err == nil {
answersJSON = string(data)
}
}
newLog := log.QueryLog{
Timestamp: queryLog.Timestamp,
ClientIP: queryLog.ClientIP,
Domain: queryLog.Domain,
QueryType: queryLog.QueryType,
ResponseTime: queryLog.ResponseTime,
Result: queryLog.Result,
BlockRule: queryLog.BlockRule,
BlockType: queryLog.BlockType,
FromCache: queryLog.FromCache,
DNSSEC: queryLog.DNSSEC,
EDNS: queryLog.EDNS,
DNSServer: queryLog.DNSServer,
DNSSECServer: queryLog.DNSSECServer,
Answers: answersJSON,
ResponseCode: queryLog.ResponseCode,
}
err := s.logManager.Log(newLog)
if err != nil {
logger.Error("新日志系统记录失败", "domain", queryLog.Domain, "error", err)
}
}
// 同时使用旧日志系统记录(兼容性)
// 发送到日志处理通道(阻塞式,确保日志不会丢失)
s.logChannel <- queryLog
logger.Debug("日志发送到通道", "domain", queryLog.Domain, "result", queryLog.Result)
}
// GetStartTime 获取服务器启动时间
@@ -2477,6 +2574,74 @@ func (s *Server) GetStats() *Stats {
// GetQueryLogs 获取查询日志
func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm, queryType string) []QueryLog {
// 优先使用新日志系统查询(如果已初始化)
if s.logManager != nil {
logger.Debug("使用新日志系统查询", "filter", resultFilter, "search", searchTerm, "queryType", queryType)
// 转换排序字段名称以匹配数据库字段
dbSortField := sortField
if sortField == "time" {
dbSortField = "timestamp"
} else if sortField == "clientIp" {
dbSortField = "client_ip"
} else if sortField == "responseTime" {
dbSortField = "response_time"
} else if sortField == "blockRule" {
dbSortField = "block_rule"
}
filter := log.LogFilter{
Result: resultFilter,
SearchTerm: searchTerm,
QueryType: queryType,
}
page := log.PageParams{
Limit: limit,
Offset: offset,
SortField: dbSortField,
SortDirection: sortDirection,
}
logs, total, err := s.logManager.QueryLogs(filter, page)
logger.Debug("新日志系统查询结果", "logs", len(logs), "total", total, "error", err)
if err == nil && len(logs) > 0 {
// 将新日志格式转换为旧格式
result := make([]QueryLog, 0, len(logs))
for _, newLog := range logs {
// 将 JSON 字符串解析为 DNSAnswer 数组
var answers []DNSAnswer
if newLog.Answers != "" {
json.Unmarshal([]byte(newLog.Answers), &answers)
}
oldLog := QueryLog{
Timestamp: newLog.Timestamp,
ClientIP: newLog.ClientIP,
Domain: newLog.Domain,
QueryType: newLog.QueryType,
ResponseTime: newLog.ResponseTime,
Result: newLog.Result,
BlockRule: newLog.BlockRule,
BlockType: newLog.BlockType,
FromCache: newLog.FromCache,
DNSSEC: newLog.DNSSEC,
EDNS: newLog.EDNS,
DNSServer: newLog.DNSServer,
DNSSECServer: newLog.DNSSECServer,
Answers: answers,
ResponseCode: newLog.ResponseCode,
}
result = append(result, oldLog)
}
return result
}
// 如果新系统查询失败或没有数据,降级到旧系统
logger.Debug("新日志系统查询失败或无数据,降级到旧系统", "error", err, "logs", len(logs))
}
// 使用旧日志系统(兼容模式)
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
@@ -2648,6 +2813,30 @@ func (s *Server) GetQueryLogsCount() int {
// GetQueryLogsCountWithFilter 获取带过滤条件的查询日志总数
func (s *Server) GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType string) int {
// 优先使用新日志系统查询(如果已初始化)
if s.logManager != nil {
filter := log.LogFilter{
Result: resultFilter,
SearchTerm: searchTerm,
QueryType: queryType,
}
page := log.PageParams{
Limit: 1, // 只需要总数,不需要实际数据
Offset: 0,
SortField: "",
SortDirection: "desc",
}
_, total, err := s.logManager.QueryLogs(filter, page)
if err == nil {
return int(total)
}
// 如果新系统查询失败,降级到旧系统
logger.Debug("新日志系统查询失败,降级到旧系统", "error", err)
}
// 使用旧日志系统(兼容模式)
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
@@ -2836,6 +3025,76 @@ func (s *Server) GetMonthlyStats() map[string]int64 {
return result
}
// checkThreatDetection 检查DNS查询是否存在威胁
func (s *Server) checkThreatDetection(sourceIP, domain, queryType string) {
if s.threatEngine == nil {
return
}
// 调用威胁检测引擎检查查询
alerts := s.threatEngine.CheckQuery(sourceIP, domain, queryType)
// 处理检测到的威胁
for _, alert := range alerts {
// 添加告警到告警管理器
s.alertManager.AddAlert(alert)
}
}
// GetAlerts 获取告警列表
func (s *Server) GetAlerts(limit, offset int, level string) []*threat.ThreatAlert {
if s.alertManager == nil {
return []*threat.ThreatAlert{}
}
return s.alertManager.GetAlerts(limit, offset, level)
}
// GetAlertCount 获取告警数量
func (s *Server) GetAlertCount(level string) int {
if s.alertManager == nil {
return 0
}
return s.alertManager.GetAlertCount(level)
}
// ResolveAlert 解决告警
func (s *Server) ResolveAlert(alertID, action string) bool {
if s.alertManager == nil {
return false
}
return s.alertManager.ResolveAlert(alertID, action)
}
// GetThreatDomains 获取所有威胁域名信息
func (s *Server) GetThreatDomains() []*threat.ThreatInfo {
if s.dbManager == nil {
return []*threat.ThreatInfo{}
}
return s.dbManager.GetAllThreatDomains()
}
// AddThreatDomain 添加威胁域名
func (s *Server) AddThreatDomain(threatType, name string, riskLevel int, domain string) error {
if s.dbManager == nil {
return nil
}
return s.dbManager.AddThreatDomain(threatType, name, riskLevel, domain)
}
// RemoveThreatDomain 删除威胁域名
func (s *Server) RemoveThreatDomain(domain string) error {
if s.dbManager == nil {
return nil
}
return s.dbManager.RemoveThreatDomain(domain)
}
// isPrivateIP 检测IP地址是否为内网IP
func isPrivateIP(ip string) bool {
// 解析IP地址
@@ -3029,37 +3288,6 @@ func (s *Server) loadQueryLogs() {
}
// processLogs 异步处理日志记录
func (s *Server) processLogs() {
for {
select {
case logEntry, ok := <-s.logChannel:
if !ok {
// 通道关闭,退出循环
return
}
// 加锁保护queryLogs
s.queryLogsMutex.Lock()
// 如果日志数量超过最大限制,删除最旧的日志
if len(s.queryLogs) >= s.maxQueryLogs {
// 使用切片操作保留最新的日志,避免复制整个切片
// 保留最新的s.maxQueryLogs-1条日志,然后添加新日志
s.queryLogs = s.queryLogs[len(s.queryLogs)-s.maxQueryLogs+1:]
}
// 直接添加新日志
s.queryLogs = append(s.queryLogs, logEntry)
// 解锁
s.queryLogsMutex.Unlock()
case <-s.ctx.Done():
// 上下文取消,退出循环
return
}
}
}
// saveStatsData 保存统计数据到文件
func (s *Server) saveStatsData() {
// 获取绝对路径以避免工作目录问题
@@ -3142,40 +3370,9 @@ func (s *Server) saveStatsData() {
}
logger.Info("统计数据保存成功", "file", statsFilePath)
// 保存查询日志到文件
s.saveQueryLogs(statsDir)
}
// saveQueryLogs 保存查询日志到文件
func (s *Server) saveQueryLogs(dataDir string) {
// 构建查询日志文件路径
queryLogPath := filepath.Join(dataDir, "querylog.json")
// 获取查询日志数据
s.queryLogsMutex.RLock()
logsCopy := make([]QueryLog, len(s.queryLogs))
copy(logsCopy, s.queryLogs)
s.queryLogsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(logsCopy, "", " ")
if err != nil {
logger.Error("序列化查询日志失败", "error", err)
return
}
// 写入文件
err = os.WriteFile(queryLogPath, jsonData, 0644)
if err != nil {
logger.Error("保存查询日志到文件失败", "file", queryLogPath, "error", err)
return
}
logger.Info("查询日志保存成功", "file", queryLogPath)
}
// startCpuUsageMonitor 启动CPU使用率监控
// startCpuUsageMonitor 启动 CPU 使用率监控
func (s *Server) startCpuUsageMonitor() {
ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率
defer ticker.Stop()