1070 lines
27 KiB
Go
1070 lines
27 KiB
Go
package dns
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io/ioutil"
|
||
"net"
|
||
"os"
|
||
"path/filepath"
|
||
"runtime"
|
||
"sort"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"dns-server/config"
|
||
"dns-server/logger"
|
||
"dns-server/shield"
|
||
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
// BlockedDomain is a per-domain hit counter. Despite its name it is used for
// both blocked domains and resolved (allowed) domains.
type BlockedDomain struct {
	// Domain is the queried domain name the counter belongs to.
	Domain string
	// Count is how many times the domain has been seen.
	Count int64
	// LastSeen is the time of the most recent hit.
	LastSeen time.Time
}
|
||
|
||
// ClientStats is the per-client query counter, keyed by client IP.
type ClientStats struct {
	// IP is the client address (without port).
	IP string
	// Count is the number of queries received from this client.
	Count int64
	// LastSeen is the time of the client's most recent query.
	LastSeen time.Time
}
|
||
|
||
// QueryLog is one entry of the in-memory query log.
type QueryLog struct {
	Timestamp    time.Time // when the entry was recorded
	ClientIP     string    // client address
	Domain       string    // queried domain
	QueryType    string    // record type, e.g. "A" or "AAAA"
	ResponseTime int64     // time taken to answer, in milliseconds
	Result       string    // outcome: "allowed", "blocked" or "error"
	BlockRule    string    // matching block rule (blocked queries only)
	BlockType    string    // kind of block rule (blocked queries only)
}
|
||
|
||
// StatsData 用于持久化的统计数据结构
|
||
type StatsData struct {
|
||
Stats *Stats `json:"stats"`
|
||
BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"`
|
||
ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"`
|
||
ClientStats map[string]*ClientStats `json:"clientStats"`
|
||
HourlyStats map[string]int64 `json:"hourlyStats"`
|
||
DailyStats map[string]int64 `json:"dailyStats"`
|
||
MonthlyStats map[string]int64 `json:"monthlyStats"`
|
||
LastSaved time.Time `json:"lastSaved"`
|
||
}
|
||
|
||
// Server DNS服务器
|
||
type Server struct {
|
||
config *config.DNSConfig
|
||
shieldConfig *config.ShieldConfig
|
||
shieldManager *shield.ShieldManager
|
||
server *dns.Server
|
||
tcpServer *dns.Server
|
||
resolver *dns.Client
|
||
ctx context.Context
|
||
cancel context.CancelFunc
|
||
statsMutex sync.Mutex
|
||
stats *Stats
|
||
blockedDomainsMutex sync.RWMutex
|
||
blockedDomains map[string]*BlockedDomain
|
||
resolvedDomainsMutex sync.RWMutex
|
||
resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名
|
||
clientStatsMutex sync.RWMutex
|
||
clientStats map[string]*ClientStats // 用于记录客户端统计
|
||
hourlyStatsMutex sync.RWMutex
|
||
hourlyStats map[string]int64 // 按小时统计屏蔽数量
|
||
dailyStatsMutex sync.RWMutex
|
||
dailyStats map[string]int64 // 按天统计屏蔽数量
|
||
monthlyStatsMutex sync.RWMutex
|
||
monthlyStats map[string]int64 // 按月统计屏蔽数量
|
||
queryLogsMutex sync.RWMutex
|
||
queryLogs []QueryLog // 查询日志列表
|
||
maxQueryLogs int // 最大保存日志数量
|
||
saveTicker *time.Ticker // 用于定时保存数据
|
||
startTime time.Time // 服务器启动时间
|
||
saveDone chan struct{} // 用于通知保存协程停止
|
||
stopped bool // 服务器是否已经停止
|
||
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
|
||
}
|
||
|
||
// Stats holds the server-wide query counters.
type Stats struct {
	Queries           int64            // total queries received
	Blocked           int64            // queries answered with a block response
	Allowed           int64            // queries answered from hosts or upstream
	Errors            int64            // queries that ended in an error response
	LastQuery         time.Time        // time of the most recent query
	AvgResponseTime   float64          // TotalResponseTime / Queries, in ms
	TotalResponseTime int64            // sum of all response times, in ms
	QueryTypes        map[string]int64 // query count per record type
	SourceIPs         map[string]bool  // set of client IPs seen
	CpuUsage          float64          // system CPU utilisation, in percent
}
|
||
|
||
// NewServer 创建DNS服务器实例
|
||
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
server := &Server{
|
||
config: config,
|
||
shieldConfig: shieldConfig,
|
||
shieldManager: shieldManager,
|
||
resolver: &dns.Client{
|
||
Net: "udp",
|
||
Timeout: time.Duration(config.Timeout) * time.Millisecond,
|
||
},
|
||
ctx: ctx,
|
||
cancel: cancel,
|
||
startTime: time.Now(), // 记录服务器启动时间
|
||
stats: &Stats{
|
||
Queries: 0,
|
||
Blocked: 0,
|
||
Allowed: 0,
|
||
Errors: 0,
|
||
AvgResponseTime: 0,
|
||
TotalResponseTime: 0,
|
||
QueryTypes: make(map[string]int64),
|
||
SourceIPs: make(map[string]bool),
|
||
CpuUsage: 0,
|
||
},
|
||
blockedDomains: make(map[string]*BlockedDomain),
|
||
resolvedDomains: make(map[string]*BlockedDomain),
|
||
clientStats: make(map[string]*ClientStats),
|
||
hourlyStats: make(map[string]int64),
|
||
dailyStats: make(map[string]int64),
|
||
monthlyStats: make(map[string]int64),
|
||
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片,容量1000
|
||
maxQueryLogs: 10000, // 最大保存10000条日志
|
||
saveDone: make(chan struct{}),
|
||
stopped: false, // 初始化为未停止状态
|
||
}
|
||
|
||
// 加载已保存的统计数据
|
||
server.loadStatsData()
|
||
|
||
return server
|
||
|
||
}
|
||
|
||
// Start 启动DNS服务器
|
||
func (s *Server) Start() error {
|
||
// 重新初始化上下文和取消函数
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
s.ctx = ctx
|
||
s.cancel = cancel
|
||
|
||
// 重新初始化saveDone通道
|
||
s.saveDone = make(chan struct{})
|
||
|
||
// 重置stopped标志
|
||
s.stoppedMutex.Lock()
|
||
s.stopped = false
|
||
s.stoppedMutex.Unlock()
|
||
|
||
// 更新服务器启动时间
|
||
s.startTime = time.Now()
|
||
|
||
s.server = &dns.Server{
|
||
Addr: fmt.Sprintf(":%d", s.config.Port),
|
||
Net: "udp",
|
||
Handler: dns.HandlerFunc(s.handleDNSRequest),
|
||
}
|
||
|
||
// 保存TCP服务器实例,以便在Stop方法中关闭
|
||
s.tcpServer = &dns.Server{
|
||
Addr: fmt.Sprintf(":%d", s.config.Port),
|
||
Net: "tcp",
|
||
Handler: dns.HandlerFunc(s.handleDNSRequest),
|
||
}
|
||
|
||
// 启动CPU使用率监控
|
||
go s.startCpuUsageMonitor()
|
||
|
||
// 启动自动保存功能
|
||
go s.startAutoSave()
|
||
|
||
// 启动UDP服务
|
||
go func() {
|
||
logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port))
|
||
if err := s.server.ListenAndServe(); err != nil {
|
||
logger.Error("DNS UDP服务器启动失败", "error", err)
|
||
s.cancel()
|
||
}
|
||
}()
|
||
|
||
// 启动TCP服务
|
||
go func() {
|
||
logger.Info(fmt.Sprintf("DNS TCP服务器启动,监听端口: %d", s.config.Port))
|
||
if err := s.tcpServer.ListenAndServe(); err != nil {
|
||
logger.Error("DNS TCP服务器启动失败", "error", err)
|
||
s.cancel()
|
||
}
|
||
}()
|
||
|
||
// 等待停止信号
|
||
<-s.ctx.Done()
|
||
return nil
|
||
}
|
||
|
||
// Stop 停止DNS服务器
|
||
func (s *Server) Stop() {
|
||
// 检查服务器是否已经停止
|
||
s.stoppedMutex.Lock()
|
||
if s.stopped {
|
||
s.stoppedMutex.Unlock()
|
||
return // 服务器已经停止,直接返回
|
||
}
|
||
// 标记服务器为已停止状态
|
||
s.stopped = true
|
||
s.stoppedMutex.Unlock()
|
||
|
||
// 发送停止信号给保存协程
|
||
close(s.saveDone)
|
||
|
||
// 最后保存一次数据
|
||
s.saveStatsData()
|
||
|
||
// 停止服务器
|
||
s.cancel()
|
||
if s.server != nil {
|
||
s.server.Shutdown()
|
||
}
|
||
if s.tcpServer != nil {
|
||
s.tcpServer.Shutdown()
|
||
}
|
||
logger.Info("DNS服务器已停止")
|
||
}
|
||
|
||
// handleDNSRequest 处理DNS请求
|
||
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||
startTime := time.Now()
|
||
|
||
// 获取来源IP
|
||
sourceIP := w.RemoteAddr().String()
|
||
// 提取IP地址部分,去掉端口
|
||
if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 {
|
||
sourceIP = sourceIP[:idx]
|
||
}
|
||
|
||
// 更新来源IP统计
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Queries++
|
||
stats.LastQuery = time.Now()
|
||
stats.SourceIPs[sourceIP] = true
|
||
})
|
||
|
||
// 更新客户端统计
|
||
s.updateClientStats(sourceIP)
|
||
|
||
// 获取查询域名和类型
|
||
var domain string
|
||
var queryType string
|
||
if len(r.Question) > 0 {
|
||
domain = r.Question[0].Name
|
||
// 移除末尾的点
|
||
if len(domain) > 0 && domain[len(domain)-1] == '.' {
|
||
domain = domain[:len(domain)-1]
|
||
}
|
||
// 获取查询类型
|
||
queryType = dns.TypeToString[r.Question[0].Qtype]
|
||
// 更新查询类型统计
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.QueryTypes[queryType]++
|
||
})
|
||
}
|
||
|
||
logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr())
|
||
|
||
// 只处理递归查询
|
||
if r.RecursionDesired == false {
|
||
response := new(dns.Msg)
|
||
response.SetReply(r)
|
||
response.RecursionAvailable = true
|
||
response.SetRcode(r, dns.RcodeRefused)
|
||
w.WriteMsg(response)
|
||
|
||
// 计算响应时间
|
||
responseTime := time.Since(startTime).Milliseconds()
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
if stats.Queries > 0 {
|
||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
}
|
||
})
|
||
|
||
// 添加查询日志
|
||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "")
|
||
return
|
||
}
|
||
|
||
// 检查hosts文件是否有匹配
|
||
if ip, exists := s.shieldManager.GetHostsIP(domain); exists {
|
||
s.handleHostsResponse(w, r, ip)
|
||
// 计算响应时间
|
||
responseTime := time.Since(startTime).Milliseconds()
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
if stats.Queries > 0 {
|
||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
}
|
||
})
|
||
|
||
// 添加查询日志
|
||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "")
|
||
return
|
||
}
|
||
|
||
// 检查是否被屏蔽
|
||
if s.shieldManager.IsBlocked(domain) {
|
||
// 获取屏蔽详情
|
||
blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
|
||
blockRule, _ := blockDetails["blockRule"].(string)
|
||
blockType, _ := blockDetails["blockRuleType"].(string)
|
||
|
||
s.handleBlockedResponse(w, r, domain)
|
||
// 计算响应时间
|
||
responseTime := time.Since(startTime).Milliseconds()
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
if stats.Queries > 0 {
|
||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
}
|
||
})
|
||
|
||
// 添加查询日志
|
||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType)
|
||
return
|
||
}
|
||
|
||
// 转发到上游DNS服务器
|
||
s.forwardDNSRequest(w, r, domain)
|
||
// 计算响应时间
|
||
responseTime := time.Since(startTime).Milliseconds()
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
if stats.Queries > 0 {
|
||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
}
|
||
})
|
||
|
||
// 添加查询日志
|
||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "")
|
||
}
|
||
|
||
// handleHostsResponse 处理hosts文件匹配的响应
|
||
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
|
||
response := new(dns.Msg)
|
||
response.SetReply(r)
|
||
response.RecursionAvailable = true
|
||
|
||
if len(r.Question) > 0 {
|
||
q := r.Question[0]
|
||
answer := new(dns.A)
|
||
answer.Hdr = dns.RR_Header{
|
||
Name: q.Name,
|
||
Rrtype: dns.TypeA,
|
||
Class: dns.ClassINET,
|
||
Ttl: 300,
|
||
}
|
||
answer.A = net.ParseIP(ip)
|
||
response.Answer = append(response.Answer, answer)
|
||
}
|
||
|
||
// 记录解析域名统计
|
||
domain := ""
|
||
if len(r.Question) > 0 {
|
||
domain = r.Question[0].Name
|
||
if len(domain) > 0 && domain[len(domain)-1] == '.' {
|
||
domain = domain[:len(domain)-1]
|
||
}
|
||
s.updateResolvedDomainStats(domain)
|
||
}
|
||
|
||
w.WriteMsg(response)
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Allowed++
|
||
})
|
||
}
|
||
|
||
// handleBlockedResponse 处理被屏蔽的域名响应
|
||
func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||
logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr())
|
||
|
||
// 更新被屏蔽域名统计
|
||
s.updateBlockedDomainStats(domain)
|
||
|
||
response := new(dns.Msg)
|
||
response.SetReply(r)
|
||
response.RecursionAvailable = true
|
||
|
||
// 获取屏蔽方法配置
|
||
blockMethod := "NXDOMAIN" // 默认值
|
||
customBlockIP := "" // 默认值
|
||
|
||
// 从Server结构体的shieldConfig字段获取配置
|
||
if s.shieldConfig != nil {
|
||
blockMethod = s.shieldConfig.BlockMethod
|
||
customBlockIP = s.shieldConfig.CustomBlockIP
|
||
}
|
||
|
||
// 根据屏蔽方法返回不同的响应
|
||
switch blockMethod {
|
||
case "refused":
|
||
// 返回拒绝查询响应
|
||
response.SetRcode(r, dns.RcodeRefused)
|
||
case "emptyIP":
|
||
// 返回空IP响应
|
||
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
|
||
answer := new(dns.A)
|
||
answer.Hdr = dns.RR_Header{
|
||
Name: r.Question[0].Name,
|
||
Rrtype: dns.TypeA,
|
||
Class: dns.ClassINET,
|
||
Ttl: 300,
|
||
}
|
||
answer.A = net.ParseIP("0.0.0.0") // 空IP
|
||
response.Answer = append(response.Answer, answer)
|
||
}
|
||
case "customIP":
|
||
// 返回自定义IP响应
|
||
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA && customBlockIP != "" {
|
||
answer := new(dns.A)
|
||
answer.Hdr = dns.RR_Header{
|
||
Name: r.Question[0].Name,
|
||
Rrtype: dns.TypeA,
|
||
Class: dns.ClassINET,
|
||
Ttl: 300,
|
||
}
|
||
answer.A = net.ParseIP(customBlockIP)
|
||
response.Answer = append(response.Answer, answer)
|
||
}
|
||
case "NXDOMAIN", "":
|
||
fallthrough // 默认使用NXDOMAIN
|
||
default:
|
||
// 返回NXDOMAIN响应(域名不存在)
|
||
response.SetRcode(r, dns.RcodeNameError)
|
||
}
|
||
|
||
w.WriteMsg(response)
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Blocked++
|
||
})
|
||
}
|
||
|
||
// forwardDNSRequest 转发DNS请求到上游服务器
|
||
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||
// 尝试所有上游DNS服务器
|
||
for _, upstream := range s.config.UpstreamDNS {
|
||
response, rtt, err := s.resolver.Exchange(r, upstream)
|
||
if err == nil && response != nil && response.Rcode == dns.RcodeSuccess {
|
||
// 设置递归可用标志
|
||
response.RecursionAvailable = true
|
||
|
||
w.WriteMsg(response)
|
||
logger.Debug("DNS查询成功", "domain", domain, "rtt", rtt, "server", upstream)
|
||
|
||
// 记录解析域名统计
|
||
s.updateResolvedDomainStats(domain)
|
||
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Allowed++
|
||
})
|
||
return
|
||
}
|
||
}
|
||
|
||
// 所有上游服务器都失败,返回服务器失败错误
|
||
response := new(dns.Msg)
|
||
response.SetReply(r)
|
||
response.RecursionAvailable = true
|
||
response.SetRcode(r, dns.RcodeServerFailure)
|
||
w.WriteMsg(response)
|
||
|
||
logger.Error("DNS查询失败", "domain", domain)
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Errors++
|
||
})
|
||
}
|
||
|
||
// updateBlockedDomainStats 更新被屏蔽域名统计
|
||
func (s *Server) updateBlockedDomainStats(domain string) {
|
||
// 更新被屏蔽域名计数
|
||
s.blockedDomainsMutex.Lock()
|
||
defer s.blockedDomainsMutex.Unlock()
|
||
|
||
if entry, exists := s.blockedDomains[domain]; exists {
|
||
entry.Count++
|
||
entry.LastSeen = time.Now()
|
||
} else {
|
||
s.blockedDomains[domain] = &BlockedDomain{
|
||
Domain: domain,
|
||
Count: 1,
|
||
LastSeen: time.Now(),
|
||
}
|
||
}
|
||
|
||
// 更新统计数据
|
||
now := time.Now()
|
||
|
||
// 更新小时统计
|
||
hourKey := now.Format("2006-01-02-15")
|
||
s.hourlyStatsMutex.Lock()
|
||
s.hourlyStats[hourKey]++
|
||
s.hourlyStatsMutex.Unlock()
|
||
|
||
// 更新每日统计
|
||
dayKey := now.Format("2006-01-02")
|
||
s.dailyStatsMutex.Lock()
|
||
s.dailyStats[dayKey]++
|
||
s.dailyStatsMutex.Unlock()
|
||
|
||
// 更新每月统计
|
||
monthKey := now.Format("2006-01")
|
||
s.monthlyStatsMutex.Lock()
|
||
s.monthlyStats[monthKey]++
|
||
s.monthlyStatsMutex.Unlock()
|
||
}
|
||
|
||
// updateClientStats 更新客户端统计
|
||
func (s *Server) updateClientStats(ip string) {
|
||
s.clientStatsMutex.Lock()
|
||
defer s.clientStatsMutex.Unlock()
|
||
|
||
if entry, exists := s.clientStats[ip]; exists {
|
||
entry.Count++
|
||
entry.LastSeen = time.Now()
|
||
} else {
|
||
s.clientStats[ip] = &ClientStats{
|
||
IP: ip,
|
||
Count: 1,
|
||
LastSeen: time.Now(),
|
||
}
|
||
}
|
||
}
|
||
|
||
// updateResolvedDomainStats 更新解析域名统计
|
||
func (s *Server) updateResolvedDomainStats(domain string) {
|
||
s.resolvedDomainsMutex.Lock()
|
||
defer s.resolvedDomainsMutex.Unlock()
|
||
|
||
if entry, exists := s.resolvedDomains[domain]; exists {
|
||
entry.Count++
|
||
entry.LastSeen = time.Now()
|
||
} else {
|
||
s.resolvedDomains[domain] = &BlockedDomain{
|
||
Domain: domain,
|
||
Count: 1,
|
||
LastSeen: time.Now(),
|
||
}
|
||
}
|
||
}
|
||
|
||
// updateStats 更新统计信息
|
||
func (s *Server) updateStats(update func(*Stats)) {
|
||
s.statsMutex.Lock()
|
||
defer s.statsMutex.Unlock()
|
||
update(s.stats)
|
||
}
|
||
|
||
// addQueryLog 添加查询日志
|
||
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string) {
|
||
// 创建日志记录
|
||
log := QueryLog{
|
||
Timestamp: time.Now(),
|
||
ClientIP: clientIP,
|
||
Domain: domain,
|
||
QueryType: queryType,
|
||
ResponseTime: responseTime,
|
||
Result: result,
|
||
BlockRule: blockRule,
|
||
BlockType: blockType,
|
||
}
|
||
|
||
// 添加到日志列表
|
||
s.queryLogsMutex.Lock()
|
||
defer s.queryLogsMutex.Unlock()
|
||
|
||
// 插入到列表开头
|
||
s.queryLogs = append([]QueryLog{log}, s.queryLogs...)
|
||
|
||
// 限制日志数量
|
||
if len(s.queryLogs) > s.maxQueryLogs {
|
||
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
|
||
}
|
||
}
|
||
|
||
// GetStartTime 获取服务器启动时间
|
||
func (s *Server) GetStartTime() time.Time {
|
||
return s.startTime
|
||
}
|
||
|
||
// GetStats 获取DNS服务器统计信息
|
||
func (s *Server) GetStats() *Stats {
|
||
s.statsMutex.Lock()
|
||
defer s.statsMutex.Unlock()
|
||
|
||
// 复制查询类型统计
|
||
queryTypesCopy := make(map[string]int64)
|
||
for k, v := range s.stats.QueryTypes {
|
||
queryTypesCopy[k] = v
|
||
}
|
||
|
||
// 复制来源IP统计
|
||
sourceIPsCopy := make(map[string]bool)
|
||
for ip := range s.stats.SourceIPs {
|
||
sourceIPsCopy[ip] = true
|
||
}
|
||
|
||
// 返回统计信息的副本
|
||
return &Stats{
|
||
Queries: s.stats.Queries,
|
||
Blocked: s.stats.Blocked,
|
||
Allowed: s.stats.Allowed,
|
||
Errors: s.stats.Errors,
|
||
LastQuery: s.stats.LastQuery,
|
||
AvgResponseTime: s.stats.AvgResponseTime,
|
||
TotalResponseTime: s.stats.TotalResponseTime,
|
||
QueryTypes: queryTypesCopy,
|
||
SourceIPs: sourceIPsCopy,
|
||
CpuUsage: s.stats.CpuUsage,
|
||
}
|
||
}
|
||
|
||
// GetQueryLogs 获取查询日志
|
||
func (s *Server) GetQueryLogs(limit, offset int) []QueryLog {
|
||
s.queryLogsMutex.RLock()
|
||
defer s.queryLogsMutex.RUnlock()
|
||
|
||
// 确保偏移量和限制值合理
|
||
if offset < 0 {
|
||
offset = 0
|
||
}
|
||
if limit <= 0 {
|
||
limit = 100 // 默认返回100条日志
|
||
}
|
||
|
||
// 计算返回范围
|
||
start := offset
|
||
end := offset + limit
|
||
if end > len(s.queryLogs) {
|
||
end = len(s.queryLogs)
|
||
}
|
||
if start >= len(s.queryLogs) {
|
||
return []QueryLog{} // 没有数据,返回空切片
|
||
}
|
||
|
||
return s.queryLogs[start:end]
|
||
}
|
||
|
||
// GetQueryLogsCount 获取查询日志总数
|
||
func (s *Server) GetQueryLogsCount() int {
|
||
s.queryLogsMutex.RLock()
|
||
defer s.queryLogsMutex.RUnlock()
|
||
return len(s.queryLogs)
|
||
}
|
||
|
||
// GetQueryStats 获取查询统计信息
|
||
func (s *Server) GetQueryStats() map[string]interface{} {
|
||
s.statsMutex.Lock()
|
||
defer s.statsMutex.Unlock()
|
||
|
||
// 计算统计数据
|
||
return map[string]interface{}{
|
||
"totalQueries": s.stats.Queries,
|
||
"blockedQueries": s.stats.Blocked,
|
||
"allowedQueries": s.stats.Allowed,
|
||
"errorQueries": s.stats.Errors,
|
||
"avgResponseTime": s.stats.AvgResponseTime,
|
||
"activeIPs": len(s.stats.SourceIPs),
|
||
}
|
||
}
|
||
|
||
// GetTopBlockedDomains 获取TOP屏蔽域名列表
|
||
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
|
||
s.blockedDomainsMutex.RLock()
|
||
defer s.blockedDomainsMutex.RUnlock()
|
||
|
||
// 转换为切片
|
||
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
|
||
for _, entry := range s.blockedDomains {
|
||
domains = append(domains, *entry)
|
||
}
|
||
|
||
// 按计数排序
|
||
sort.Slice(domains, func(i, j int) bool {
|
||
return domains[i].Count > domains[j].Count
|
||
})
|
||
|
||
// 返回限制数量
|
||
if len(domains) > limit {
|
||
return domains[:limit]
|
||
}
|
||
return domains
|
||
}
|
||
|
||
// GetTopResolvedDomains 获取TOP解析域名
|
||
func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain {
|
||
s.resolvedDomainsMutex.RLock()
|
||
defer s.resolvedDomainsMutex.RUnlock()
|
||
|
||
// 转换为切片
|
||
domains := make([]BlockedDomain, 0, len(s.resolvedDomains))
|
||
for _, entry := range s.resolvedDomains {
|
||
domains = append(domains, *entry)
|
||
}
|
||
|
||
// 按数量排序
|
||
sort.Slice(domains, func(i, j int) bool {
|
||
return domains[i].Count > domains[j].Count
|
||
})
|
||
|
||
// 返回限制数量
|
||
if len(domains) > limit {
|
||
return domains[:limit]
|
||
}
|
||
return domains
|
||
}
|
||
|
||
// GetRecentBlockedDomains 获取最近屏蔽的域名列表
|
||
func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain {
|
||
s.blockedDomainsMutex.RLock()
|
||
defer s.blockedDomainsMutex.RUnlock()
|
||
|
||
// 转换为切片
|
||
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
|
||
for _, entry := range s.blockedDomains {
|
||
domains = append(domains, *entry)
|
||
}
|
||
|
||
// 按时间排序
|
||
sort.Slice(domains, func(i, j int) bool {
|
||
return domains[i].LastSeen.After(domains[j].LastSeen)
|
||
})
|
||
|
||
// 返回限制数量
|
||
if len(domains) > limit {
|
||
return domains[:limit]
|
||
}
|
||
return domains
|
||
}
|
||
|
||
// GetTopClients 获取TOP客户端列表
|
||
func (s *Server) GetTopClients(limit int) []ClientStats {
|
||
s.clientStatsMutex.RLock()
|
||
defer s.clientStatsMutex.RUnlock()
|
||
|
||
// 转换为切片
|
||
clients := make([]ClientStats, 0, len(s.clientStats))
|
||
for _, entry := range s.clientStats {
|
||
clients = append(clients, *entry)
|
||
}
|
||
|
||
// 按请求次数排序
|
||
sort.Slice(clients, func(i, j int) bool {
|
||
return clients[i].Count > clients[j].Count
|
||
})
|
||
|
||
// 返回限制数量
|
||
if len(clients) > limit {
|
||
return clients[:limit]
|
||
}
|
||
return clients
|
||
}
|
||
|
||
// GetHourlyStats 获取每小时统计数据
|
||
func (s *Server) GetHourlyStats() map[string]int64 {
|
||
s.hourlyStatsMutex.RLock()
|
||
defer s.hourlyStatsMutex.RUnlock()
|
||
|
||
// 返回副本
|
||
result := make(map[string]int64)
|
||
for k, v := range s.hourlyStats {
|
||
result[k] = v
|
||
}
|
||
return result
|
||
}
|
||
|
||
// GetDailyStats 获取每日统计数据
|
||
func (s *Server) GetDailyStats() map[string]int64 {
|
||
s.dailyStatsMutex.RLock()
|
||
defer s.dailyStatsMutex.RUnlock()
|
||
|
||
// 返回副本
|
||
result := make(map[string]int64)
|
||
for k, v := range s.dailyStats {
|
||
result[k] = v
|
||
}
|
||
return result
|
||
}
|
||
|
||
// GetMonthlyStats 获取每月统计数据
|
||
func (s *Server) GetMonthlyStats() map[string]int64 {
|
||
s.monthlyStatsMutex.RLock()
|
||
defer s.monthlyStatsMutex.RUnlock()
|
||
|
||
// 返回副本
|
||
result := make(map[string]int64)
|
||
for k, v := range s.monthlyStats {
|
||
result[k] = v
|
||
}
|
||
return result
|
||
}
|
||
|
||
// loadStatsData 从文件加载统计数据
|
||
func (s *Server) loadStatsData() {
|
||
if s.config.StatsFile == "" {
|
||
return
|
||
}
|
||
|
||
// 检查文件是否存在
|
||
data, err := ioutil.ReadFile(s.config.StatsFile)
|
||
if err != nil {
|
||
if !os.IsNotExist(err) {
|
||
logger.Error("读取统计数据文件失败", "error", err)
|
||
}
|
||
return
|
||
}
|
||
|
||
var statsData StatsData
|
||
err = json.Unmarshal(data, &statsData)
|
||
if err != nil {
|
||
logger.Error("解析统计数据失败", "error", err)
|
||
return
|
||
}
|
||
|
||
// 恢复统计数据
|
||
s.statsMutex.Lock()
|
||
if statsData.Stats != nil {
|
||
s.stats = statsData.Stats
|
||
}
|
||
s.statsMutex.Unlock()
|
||
|
||
s.blockedDomainsMutex.Lock()
|
||
if statsData.BlockedDomains != nil {
|
||
s.blockedDomains = statsData.BlockedDomains
|
||
}
|
||
s.blockedDomainsMutex.Unlock()
|
||
|
||
s.resolvedDomainsMutex.Lock()
|
||
if statsData.ResolvedDomains != nil {
|
||
s.resolvedDomains = statsData.ResolvedDomains
|
||
}
|
||
s.resolvedDomainsMutex.Unlock()
|
||
|
||
s.hourlyStatsMutex.Lock()
|
||
if statsData.HourlyStats != nil {
|
||
s.hourlyStats = statsData.HourlyStats
|
||
}
|
||
s.hourlyStatsMutex.Unlock()
|
||
|
||
s.dailyStatsMutex.Lock()
|
||
if statsData.DailyStats != nil {
|
||
s.dailyStats = statsData.DailyStats
|
||
}
|
||
s.dailyStatsMutex.Unlock()
|
||
|
||
s.monthlyStatsMutex.Lock()
|
||
if statsData.MonthlyStats != nil {
|
||
s.monthlyStats = statsData.MonthlyStats
|
||
}
|
||
s.monthlyStatsMutex.Unlock()
|
||
|
||
// 加载客户端统计数据
|
||
s.clientStatsMutex.Lock()
|
||
if statsData.ClientStats != nil {
|
||
s.clientStats = statsData.ClientStats
|
||
}
|
||
s.clientStatsMutex.Unlock()
|
||
|
||
logger.Info("统计数据加载成功")
|
||
}
|
||
|
||
// saveStatsData 保存统计数据到文件
|
||
func (s *Server) saveStatsData() {
|
||
if s.config.StatsFile == "" {
|
||
return
|
||
}
|
||
|
||
// 获取绝对路径以避免工作目录问题
|
||
statsFilePath, err := filepath.Abs(s.config.StatsFile)
|
||
if err != nil {
|
||
logger.Error("获取统计文件绝对路径失败", "path", s.config.StatsFile, "error", err)
|
||
return
|
||
}
|
||
|
||
// 创建数据目录
|
||
statsDir := filepath.Dir(statsFilePath)
|
||
err = os.MkdirAll(statsDir, 0755)
|
||
if err != nil {
|
||
logger.Error("创建统计数据目录失败", "dir", statsDir, "error", err)
|
||
return
|
||
}
|
||
|
||
// 收集所有统计数据
|
||
statsData := &StatsData{
|
||
Stats: s.GetStats(),
|
||
LastSaved: time.Now(),
|
||
}
|
||
|
||
// 复制域名数据
|
||
s.blockedDomainsMutex.RLock()
|
||
statsData.BlockedDomains = make(map[string]*BlockedDomain)
|
||
for k, v := range s.blockedDomains {
|
||
statsData.BlockedDomains[k] = v
|
||
}
|
||
s.blockedDomainsMutex.RUnlock()
|
||
|
||
s.resolvedDomainsMutex.RLock()
|
||
statsData.ResolvedDomains = make(map[string]*BlockedDomain)
|
||
for k, v := range s.resolvedDomains {
|
||
statsData.ResolvedDomains[k] = v
|
||
}
|
||
s.resolvedDomainsMutex.RUnlock()
|
||
|
||
s.hourlyStatsMutex.RLock()
|
||
statsData.HourlyStats = make(map[string]int64)
|
||
for k, v := range s.hourlyStats {
|
||
statsData.HourlyStats[k] = v
|
||
}
|
||
s.hourlyStatsMutex.RUnlock()
|
||
|
||
s.dailyStatsMutex.RLock()
|
||
statsData.DailyStats = make(map[string]int64)
|
||
for k, v := range s.dailyStats {
|
||
statsData.DailyStats[k] = v
|
||
}
|
||
s.dailyStatsMutex.RUnlock()
|
||
|
||
s.monthlyStatsMutex.RLock()
|
||
statsData.MonthlyStats = make(map[string]int64)
|
||
for k, v := range s.monthlyStats {
|
||
statsData.MonthlyStats[k] = v
|
||
}
|
||
s.monthlyStatsMutex.RUnlock()
|
||
|
||
// 复制客户端统计数据
|
||
s.clientStatsMutex.RLock()
|
||
statsData.ClientStats = make(map[string]*ClientStats)
|
||
for k, v := range s.clientStats {
|
||
statsData.ClientStats[k] = v
|
||
}
|
||
s.clientStatsMutex.RUnlock()
|
||
|
||
// 序列化数据
|
||
jsonData, err := json.MarshalIndent(statsData, "", " ")
|
||
if err != nil {
|
||
logger.Error("序列化统计数据失败", "error", err)
|
||
return
|
||
}
|
||
|
||
// 写入文件
|
||
err = os.WriteFile(statsFilePath, jsonData, 0644)
|
||
if err != nil {
|
||
logger.Error("保存统计数据到文件失败", "file", statsFilePath, "error", err)
|
||
return
|
||
}
|
||
|
||
logger.Info("统计数据保存成功", "file", statsFilePath)
|
||
}
|
||
|
||
// startCpuUsageMonitor 启动CPU使用率监控
|
||
func (s *Server) startCpuUsageMonitor() {
|
||
ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率
|
||
defer ticker.Stop()
|
||
|
||
// 初始化
|
||
var memStats runtime.MemStats
|
||
runtime.ReadMemStats(&memStats)
|
||
|
||
// 存储上一次的CPU时间统计
|
||
var prevIdle, prevTotal uint64
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
// 获取真实的系统级CPU使用率
|
||
cpuUsage, err := getSystemCpuUsage(&prevIdle, &prevTotal)
|
||
if err != nil {
|
||
// 如果获取失败,使用默认值
|
||
cpuUsage = 0.0
|
||
logger.Error("获取系统CPU使用率失败", "error", err)
|
||
}
|
||
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.CpuUsage = cpuUsage
|
||
})
|
||
case <-s.ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// getSystemCpuUsage returns the system-wide CPU utilisation, in percent, over
// the interval since the previous call. It reads the aggregate "cpu" line of
// /proc/stat. prevIdle and prevTotal carry the previous sample's counters
// between calls, so pass pointers to the same two variables every time.
//
// The first call (either stored counter zero) only records a baseline and
// returns 0. If the counters did not advance, or went backwards, the baseline
// is refreshed and 0 is returned. Previously this produced NaN (0/0) or a
// wrapped-around value; a NaN also makes json.Marshal of the stats fail. The
// result otherwise lies in [0, 100].
func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
	file, err := os.Open("/proc/stat")
	if err != nil {
		return 0, err
	}
	defer file.Close()

	var cpuUser, cpuNice, cpuSystem, cpuIdle, cpuIowait, cpuIrq, cpuSoftirq, cpuSteal uint64
	if _, err := fmt.Fscanf(file, "cpu %d %d %d %d %d %d %d %d",
		&cpuUser, &cpuNice, &cpuSystem, &cpuIdle, &cpuIowait, &cpuIrq, &cpuSoftirq, &cpuSteal); err != nil {
		return 0, err
	}

	total := cpuUser + cpuNice + cpuSystem + cpuIdle + cpuIowait + cpuIrq + cpuSoftirq + cpuSteal
	// I/O wait counts as idle time.
	idle := cpuIdle + cpuIowait

	lastIdle, lastTotal := *prevIdle, *prevTotal
	*prevIdle, *prevTotal = idle, total

	// First sample: nothing to compare against yet.
	if lastTotal == 0 || lastIdle == 0 {
		return 0, nil
	}
	// No elapsed ticks, or counters went backwards: no meaningful delta.
	if total <= lastTotal || idle < lastIdle {
		return 0, nil
	}

	totalDelta := total - lastTotal
	idleDelta := idle - lastIdle
	if idleDelta > totalDelta {
		return 0, nil
	}

	return float64(totalDelta-idleDelta) / float64(totalDelta) * 100, nil
}
|
||
|
||
// startAutoSave 启动自动保存功能
|
||
func (s *Server) startAutoSave() {
|
||
if s.config.StatsFile == "" || s.config.SaveInterval <= 0 {
|
||
return
|
||
}
|
||
|
||
// 设置定时器
|
||
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
|
||
defer s.saveTicker.Stop()
|
||
|
||
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", s.config.StatsFile)
|
||
|
||
// 定期保存数据
|
||
for {
|
||
select {
|
||
case <-s.saveTicker.C:
|
||
s.saveStatsData()
|
||
case <-s.saveDone:
|
||
return
|
||
}
|
||
}
|
||
}
|