Files
dns-server/dns/server.go
2025-11-30 13:31:16 +08:00

1398 lines
35 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dns
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
"time"
"dns-server/config"
"dns-server/logger"
"dns-server/shield"
"github.com/miekg/dns"
)
// BlockedDomain 屏蔽域名统计
type BlockedDomain struct {
Domain string
Count int64
LastSeen time.Time
}
// ClientStats 客户端统计
type ClientStats struct {
IP string
Count int64
LastSeen time.Time
}
// IPGeolocation IP地理位置信息
type IPGeolocation struct {
Country string `json:"country"` // 国家
City string `json:"city"` // 城市
Expiry time.Time `json:"expiry"` // 缓存过期时间
}
// QueryLog 查询日志记录
type QueryLog struct {
Timestamp time.Time // 查询时间
ClientIP string // 客户端IP
Location string // IP地理位置国家 城市)
Domain string // 查询域名
QueryType string // 查询类型
ResponseTime int64 // 响应时间(ms)
Result string // 查询结果allowed, blocked, error
BlockRule string // 屏蔽规则(如果被屏蔽)
BlockType string // 屏蔽类型(如果被屏蔽)
}
// StatsData 用于持久化的统计数据结构
type StatsData struct {
Stats *Stats `json:"stats"`
BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"`
ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"`
ClientStats map[string]*ClientStats `json:"clientStats"`
HourlyStats map[string]int64 `json:"hourlyStats"`
DailyStats map[string]int64 `json:"dailyStats"`
MonthlyStats map[string]int64 `json:"monthlyStats"`
LastSaved time.Time `json:"lastSaved"`
}
// Server DNS服务器
type Server struct {
config *config.DNSConfig
shieldConfig *config.ShieldConfig
shieldManager *shield.ShieldManager
server *dns.Server
tcpServer *dns.Server
resolver *dns.Client
ctx context.Context
cancel context.CancelFunc
statsMutex sync.Mutex
stats *Stats
blockedDomainsMutex sync.RWMutex
blockedDomains map[string]*BlockedDomain
resolvedDomainsMutex sync.RWMutex
resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名
clientStatsMutex sync.RWMutex
clientStats map[string]*ClientStats // 用于记录客户端统计
hourlyStatsMutex sync.RWMutex
hourlyStats map[string]int64 // 按小时统计屏蔽数量
dailyStatsMutex sync.RWMutex
dailyStats map[string]int64 // 按天统计屏蔽数量
monthlyStatsMutex sync.RWMutex
monthlyStats map[string]int64 // 按月统计屏蔽数量
queryLogsMutex sync.RWMutex
queryLogs []QueryLog // 查询日志列表
maxQueryLogs int // 最大保存日志数量
saveTicker *time.Ticker // 用于定时保存数据
startTime time.Time // 服务器启动时间
saveDone chan struct{} // 用于通知保存协程停止
stopped bool // 服务器是否已经停止
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
// IP地理位置缓存
ipGeolocationCache map[string]*IPGeolocation // IP地址到地理位置的映射
ipGeolocationCacheMutex sync.RWMutex // 保护IP地理位置缓存的互斥锁
ipGeolocationCacheTTL time.Duration // 缓存有效期
}
// Stats DNS服务器统计信息
type Stats struct {
Queries int64
Blocked int64
Allowed int64
Errors int64
LastQuery time.Time
AvgResponseTime float64 // 平均响应时间(ms)
TotalResponseTime int64 // 总响应时间
QueryTypes map[string]int64 // 查询类型统计
SourceIPs map[string]bool // 活跃来源IP
CpuUsage float64 // CPU使用率(%)
}
// NewServer 创建DNS服务器实例
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
ctx, cancel := context.WithCancel(context.Background())
server := &Server{
config: config,
shieldConfig: shieldConfig,
shieldManager: shieldManager,
resolver: &dns.Client{
Net: "udp",
Timeout: time.Duration(config.Timeout) * time.Millisecond,
},
ctx: ctx,
cancel: cancel,
startTime: time.Now(), // 记录服务器启动时间
stats: &Stats{
Queries: 0,
Blocked: 0,
Allowed: 0,
Errors: 0,
AvgResponseTime: 0,
TotalResponseTime: 0,
QueryTypes: make(map[string]int64),
SourceIPs: make(map[string]bool),
CpuUsage: 0,
},
blockedDomains: make(map[string]*BlockedDomain),
resolvedDomains: make(map[string]*BlockedDomain),
clientStats: make(map[string]*ClientStats),
hourlyStats: make(map[string]int64),
dailyStats: make(map[string]int64),
monthlyStats: make(map[string]int64),
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片容量1000
maxQueryLogs: 10000, // 最大保存10000条日志
saveDone: make(chan struct{}),
stopped: false, // 初始化为未停止状态
// IP地理位置缓存初始化
ipGeolocationCache: make(map[string]*IPGeolocation),
ipGeolocationCacheTTL: 24 * time.Hour, // 缓存有效期24小时
}
// 加载已保存的统计数据
server.loadStatsData()
return server
}
// Start 启动DNS服务器
func (s *Server) Start() error {
// 重新初始化上下文和取消函数
ctx, cancel := context.WithCancel(context.Background())
s.ctx = ctx
s.cancel = cancel
// 重新初始化saveDone通道
s.saveDone = make(chan struct{})
// 重置stopped标志
s.stoppedMutex.Lock()
s.stopped = false
s.stoppedMutex.Unlock()
// 更新服务器启动时间
s.startTime = time.Now()
s.server = &dns.Server{
Addr: fmt.Sprintf(":%d", s.config.Port),
Net: "udp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 保存TCP服务器实例以便在Stop方法中关闭
s.tcpServer = &dns.Server{
Addr: fmt.Sprintf(":%d", s.config.Port),
Net: "tcp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 启动CPU使用率监控
go s.startCpuUsageMonitor()
// 启动自动保存功能
go s.startAutoSave()
// 启动UDP服务
go func() {
logger.Info(fmt.Sprintf("DNS UDP服务器启动监听端口: %d", s.config.Port))
if err := s.server.ListenAndServe(); err != nil {
logger.Error("DNS UDP服务器启动失败", "error", err)
s.cancel()
}
}()
// 启动TCP服务
go func() {
logger.Info(fmt.Sprintf("DNS TCP服务器启动监听端口: %d", s.config.Port))
if err := s.tcpServer.ListenAndServe(); err != nil {
logger.Error("DNS TCP服务器启动失败", "error", err)
s.cancel()
}
}()
// 等待停止信号
<-s.ctx.Done()
return nil
}
// Stop 停止DNS服务器
func (s *Server) Stop() {
// 检查服务器是否已经停止
s.stoppedMutex.Lock()
if s.stopped {
s.stoppedMutex.Unlock()
return // 服务器已经停止,直接返回
}
// 标记服务器为已停止状态
s.stopped = true
s.stoppedMutex.Unlock()
// 发送停止信号给保存协程
close(s.saveDone)
// 最后保存一次数据
s.saveStatsData()
// 停止服务器
s.cancel()
if s.server != nil {
s.server.Shutdown()
}
if s.tcpServer != nil {
s.tcpServer.Shutdown()
}
logger.Info("DNS服务器已停止")
}
// handleDNSRequest 处理DNS请求
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
// 获取来源IP
sourceIP := w.RemoteAddr().String()
// 提取IP地址部分去掉端口
if strings.HasPrefix(sourceIP, "[") {
// IPv6地址格式: [::1]:53
if idx := strings.Index(sourceIP, "]"); idx >= 0 {
sourceIP = sourceIP[1:idx] // 去掉方括号
}
} else {
// IPv4地址格式: 127.0.0.1:53
if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 {
sourceIP = sourceIP[:idx]
}
}
// 更新来源IP统计
s.updateStats(func(stats *Stats) {
stats.Queries++
stats.LastQuery = time.Now()
stats.SourceIPs[sourceIP] = true
})
// 更新客户端统计
s.updateClientStats(sourceIP)
// 获取查询域名和类型
var domain string
var queryType string
if len(r.Question) > 0 {
domain = r.Question[0].Name
// 移除末尾的点
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
// 获取查询类型
queryType = dns.TypeToString[r.Question[0].Qtype]
// 更新查询类型统计
s.updateStats(func(stats *Stats) {
stats.QueryTypes[queryType]++
})
}
logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr())
// 只处理递归查询
if r.RecursionDesired == false {
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
response.SetRcode(r, dns.RcodeRefused)
w.WriteMsg(response)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "")
return
}
// 检查hosts文件是否有匹配
if ip, exists := s.shieldManager.GetHostsIP(domain); exists {
s.handleHostsResponse(w, r, ip)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "")
return
}
// 检查是否被屏蔽
if s.shieldManager.IsBlocked(domain) {
// 获取屏蔽详情
blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
blockRule, _ := blockDetails["blockRule"].(string)
blockType, _ := blockDetails["blockRuleType"].(string)
s.handleBlockedResponse(w, r, domain)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType)
return
}
// 转发到上游DNS服务器
s.forwardDNSRequest(w, r, domain)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "")
}
// handleHostsResponse 处理hosts文件匹配的响应
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
if len(r.Question) > 0 {
q := r.Question[0]
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(ip)
response.Answer = append(response.Answer, answer)
}
// 记录解析域名统计
domain := ""
if len(r.Question) > 0 {
domain = r.Question[0].Name
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
s.updateResolvedDomainStats(domain)
}
w.WriteMsg(response)
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
}
// handleBlockedResponse 处理被屏蔽的域名响应
func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr())
// 更新被屏蔽域名统计
s.updateBlockedDomainStats(domain)
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
// 获取屏蔽方法配置
blockMethod := "NXDOMAIN" // 默认值
customBlockIP := "" // 默认值
// 从Server结构体的shieldConfig字段获取配置
if s.shieldConfig != nil {
blockMethod = s.shieldConfig.BlockMethod
customBlockIP = s.shieldConfig.CustomBlockIP
}
// 根据屏蔽方法返回不同的响应
switch blockMethod {
case "refused":
// 返回拒绝查询响应
response.SetRcode(r, dns.RcodeRefused)
case "emptyIP":
// 返回空IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP("0.0.0.0") // 空IP
response.Answer = append(response.Answer, answer)
}
case "customIP":
// 返回自定义IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA && customBlockIP != "" {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(customBlockIP)
response.Answer = append(response.Answer, answer)
}
case "NXDOMAIN", "":
fallthrough // 默认使用NXDOMAIN
default:
// 返回NXDOMAIN响应域名不存在
response.SetRcode(r, dns.RcodeNameError)
}
w.WriteMsg(response)
s.updateStats(func(stats *Stats) {
stats.Blocked++
})
}
// forwardDNSRequest 转发DNS请求到上游服务器
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
// 尝试所有上游DNS服务器
for _, upstream := range s.config.UpstreamDNS {
response, rtt, err := s.resolver.Exchange(r, upstream)
if err == nil && response != nil && response.Rcode == dns.RcodeSuccess {
// 设置递归可用标志
response.RecursionAvailable = true
w.WriteMsg(response)
logger.Debug("DNS查询成功", "domain", domain, "rtt", rtt, "server", upstream)
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return
}
}
// 所有上游服务器都失败,返回服务器失败错误
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
response.SetRcode(r, dns.RcodeServerFailure)
w.WriteMsg(response)
logger.Error("DNS查询失败", "domain", domain)
s.updateStats(func(stats *Stats) {
stats.Errors++
})
}
// updateBlockedDomainStats 更新被屏蔽域名统计
func (s *Server) updateBlockedDomainStats(domain string) {
// 更新被屏蔽域名计数
s.blockedDomainsMutex.Lock()
defer s.blockedDomainsMutex.Unlock()
if entry, exists := s.blockedDomains[domain]; exists {
entry.Count++
entry.LastSeen = time.Now()
} else {
s.blockedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now(),
}
}
// 更新统计数据
now := time.Now()
// 更新小时统计
hourKey := now.Format("2006-01-02-15")
s.hourlyStatsMutex.Lock()
s.hourlyStats[hourKey]++
s.hourlyStatsMutex.Unlock()
// 更新每日统计
dayKey := now.Format("2006-01-02")
s.dailyStatsMutex.Lock()
s.dailyStats[dayKey]++
s.dailyStatsMutex.Unlock()
// 更新每月统计
monthKey := now.Format("2006-01")
s.monthlyStatsMutex.Lock()
s.monthlyStats[monthKey]++
s.monthlyStatsMutex.Unlock()
}
// updateClientStats 更新客户端统计
func (s *Server) updateClientStats(ip string) {
s.clientStatsMutex.Lock()
defer s.clientStatsMutex.Unlock()
if entry, exists := s.clientStats[ip]; exists {
entry.Count++
entry.LastSeen = time.Now()
} else {
s.clientStats[ip] = &ClientStats{
IP: ip,
Count: 1,
LastSeen: time.Now(),
}
}
}
// updateResolvedDomainStats 更新解析域名统计
func (s *Server) updateResolvedDomainStats(domain string) {
s.resolvedDomainsMutex.Lock()
defer s.resolvedDomainsMutex.Unlock()
if entry, exists := s.resolvedDomains[domain]; exists {
entry.Count++
entry.LastSeen = time.Now()
} else {
s.resolvedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now(),
}
}
}
// updateStats 更新统计信息
func (s *Server) updateStats(update func(*Stats)) {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
update(s.stats)
}
// addQueryLog 添加查询日志
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string) {
// 获取IP地理位置
location := s.getIpGeolocation(clientIP)
// 创建日志记录
log := QueryLog{
Timestamp: time.Now(),
ClientIP: clientIP,
Location: location,
Domain: domain,
QueryType: queryType,
ResponseTime: responseTime,
Result: result,
BlockRule: blockRule,
BlockType: blockType,
}
// 添加到日志列表
s.queryLogsMutex.Lock()
defer s.queryLogsMutex.Unlock()
// 插入到列表开头
s.queryLogs = append([]QueryLog{log}, s.queryLogs...)
// 限制日志数量
if len(s.queryLogs) > s.maxQueryLogs {
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
}
}
// GetStartTime 获取服务器启动时间
func (s *Server) GetStartTime() time.Time {
return s.startTime
}
// GetStats 获取DNS服务器统计信息
func (s *Server) GetStats() *Stats {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 复制查询类型统计
queryTypesCopy := make(map[string]int64)
for k, v := range s.stats.QueryTypes {
queryTypesCopy[k] = v
}
// 复制来源IP统计
sourceIPsCopy := make(map[string]bool)
for ip := range s.stats.SourceIPs {
sourceIPsCopy[ip] = true
}
// 返回统计信息的副本
return &Stats{
Queries: s.stats.Queries,
Blocked: s.stats.Blocked,
Allowed: s.stats.Allowed,
Errors: s.stats.Errors,
LastQuery: s.stats.LastQuery,
AvgResponseTime: s.stats.AvgResponseTime,
TotalResponseTime: s.stats.TotalResponseTime,
QueryTypes: queryTypesCopy,
SourceIPs: sourceIPsCopy,
CpuUsage: s.stats.CpuUsage,
}
}
// GetQueryLogs 获取查询日志
func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm string) []QueryLog {
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
// 确保偏移量和限制值合理
if offset < 0 {
offset = 0
}
if limit <= 0 {
limit = 100 // 默认返回100条日志
}
// 创建日志副本用于过滤和排序
var logsCopy []QueryLog
// 先过滤日志
for _, log := range s.queryLogs {
// 应用结果过滤
if resultFilter != "" && log.Result != resultFilter {
continue
}
// 应用搜索过滤
if searchTerm != "" {
// 搜索域名或客户端IP
if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) {
continue
}
}
logsCopy = append(logsCopy, log)
}
// 排序日志
if sortField != "" {
sort.Slice(logsCopy, func(i, j int) bool {
var a, b interface{}
switch sortField {
case "time":
a = logsCopy[i].Timestamp
b = logsCopy[j].Timestamp
case "clientIp":
a = logsCopy[i].ClientIP
b = logsCopy[j].ClientIP
case "domain":
a = logsCopy[i].Domain
b = logsCopy[j].Domain
case "responseTime":
a = logsCopy[i].ResponseTime
b = logsCopy[j].ResponseTime
case "blockRule":
a = logsCopy[i].BlockRule
b = logsCopy[j].BlockRule
default:
// 默认按时间排序
a = logsCopy[i].Timestamp
b = logsCopy[j].Timestamp
}
// 根据排序方向比较
if sortDirection == "asc" {
return compareValues(a, b) < 0
}
return compareValues(a, b) > 0
})
}
// 计算返回范围
start := offset
end := offset + limit
if end > len(logsCopy) {
end = len(logsCopy)
}
if start >= len(logsCopy) {
return []QueryLog{} // 没有数据,返回空切片
}
return logsCopy[start:end]
}
// compareValues 比较两个值
func compareValues(a, b interface{}) int {
switch v1 := a.(type) {
case time.Time:
v2 := b.(time.Time)
if v1.Before(v2) {
return -1
}
if v1.After(v2) {
return 1
}
return 0
case string:
v2 := b.(string)
if v1 < v2 {
return -1
}
if v1 > v2 {
return 1
}
return 0
case int64:
v2 := b.(int64)
if v1 < v2 {
return -1
}
if v1 > v2 {
return 1
}
return 0
default:
return 0
}
}
// GetQueryLogsCount 获取查询日志总数
func (s *Server) GetQueryLogsCount() int {
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
return len(s.queryLogs)
}
// GetQueryStats 获取查询统计信息
func (s *Server) GetQueryStats() map[string]interface{} {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 计算统计数据
return map[string]interface{}{
"totalQueries": s.stats.Queries,
"blockedQueries": s.stats.Blocked,
"allowedQueries": s.stats.Allowed,
"errorQueries": s.stats.Errors,
"avgResponseTime": s.stats.AvgResponseTime,
"activeIPs": len(s.stats.SourceIPs),
}
}
// GetTopBlockedDomains 获取TOP屏蔽域名列表
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
domains = append(domains, *entry)
}
// 按计数排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetTopResolvedDomains 获取TOP解析域名
func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain {
s.resolvedDomainsMutex.RLock()
defer s.resolvedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.resolvedDomains))
for _, entry := range s.resolvedDomains {
domains = append(domains, *entry)
}
// 按数量排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetRecentBlockedDomains 获取最近屏蔽的域名列表
func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
domains = append(domains, *entry)
}
// 按时间排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].LastSeen.After(domains[j].LastSeen)
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetTopClients 获取TOP客户端列表
func (s *Server) GetTopClients(limit int) []ClientStats {
s.clientStatsMutex.RLock()
defer s.clientStatsMutex.RUnlock()
// 转换为切片
clients := make([]ClientStats, 0, len(s.clientStats))
for _, entry := range s.clientStats {
clients = append(clients, *entry)
}
// 按请求次数排序
sort.Slice(clients, func(i, j int) bool {
return clients[i].Count > clients[j].Count
})
// 返回限制数量
if len(clients) > limit {
return clients[:limit]
}
return clients
}
// GetHourlyStats 获取每小时统计数据
func (s *Server) GetHourlyStats() map[string]int64 {
s.hourlyStatsMutex.RLock()
defer s.hourlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.hourlyStats {
result[k] = v
}
return result
}
// GetDailyStats 获取每日统计数据
func (s *Server) GetDailyStats() map[string]int64 {
s.dailyStatsMutex.RLock()
defer s.dailyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.dailyStats {
result[k] = v
}
return result
}
// GetMonthlyStats 获取每月统计数据
func (s *Server) GetMonthlyStats() map[string]int64 {
s.monthlyStatsMutex.RLock()
defer s.monthlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.monthlyStats {
result[k] = v
}
return result
}
// isPrivateIP 检测IP地址是否为内网IP
func isPrivateIP(ip string) bool {
// 解析IP地址
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
// 检查IPv4内网地址
if ipv4 := parsedIP.To4(); ipv4 != nil {
// 10.0.0.0/8
if ipv4[0] == 10 {
return true
}
// 172.16.0.0/12
if ipv4[0] == 172 && (ipv4[1] >= 16 && ipv4[1] <= 31) {
return true
}
// 192.168.0.0/16
if ipv4[0] == 192 && ipv4[1] == 168 {
return true
}
// 127.0.0.0/8 (localhost)
if ipv4[0] == 127 {
return true
}
// 169.254.0.0/16 (链路本地地址)
if ipv4[0] == 169 && ipv4[1] == 254 {
return true
}
return false
}
// 检查IPv6内网地址
// ::1/128 (localhost)
if parsedIP.IsLoopback() {
return true
}
// fc00::/7 (唯一本地地址)
if parsedIP[0]&0xfc == 0xfc {
return true
}
// fe80::/10 (链路本地地址)
if parsedIP[0]&0xfe == 0xfe && parsedIP[1]&0xc0 == 0x80 {
return true
}
return false
}
// getIpGeolocation 获取IP地址的地理位置信息
func (s *Server) getIpGeolocation(ip string) string {
// 检查IP是否为本地或内网地址
if isPrivateIP(ip) {
return "内网 内网"
}
// 先检查缓存
s.ipGeolocationCacheMutex.RLock()
geo, exists := s.ipGeolocationCache[ip]
s.ipGeolocationCacheMutex.RUnlock()
// 如果缓存存在且未过期,直接返回
if exists && time.Now().Before(geo.Expiry) {
return fmt.Sprintf("%s %s", geo.Country, geo.City)
}
// 缓存不存在或已过期从API获取
geoInfo, err := s.fetchIpGeolocationFromAPI(ip)
if err != nil {
logger.Error("获取IP地理位置失败", "ip", ip, "error", err)
return "未知 未知"
}
// 保存到缓存
s.ipGeolocationCacheMutex.Lock()
s.ipGeolocationCache[ip] = &IPGeolocation{
Country: geoInfo["country"].(string),
City: geoInfo["city"].(string),
Expiry: time.Now().Add(s.ipGeolocationCacheTTL),
}
s.ipGeolocationCacheMutex.Unlock()
// 返回格式化的地理位置
return fmt.Sprintf("%s %s", geoInfo["country"].(string), geoInfo["city"].(string))
}
// fetchIpGeolocationFromAPI 从第三方API获取IP地理位置信息
func (s *Server) fetchIpGeolocationFromAPI(ip string) (map[string]interface{}, error) {
// 使用ip-api.com获取IP地理位置信息
url := fmt.Sprintf("http://ip-api.com/json/%s?fields=country,city", ip)
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// 读取响应内容
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 解析JSON响应
var result map[string]interface{}
err = json.Unmarshal(body, &result)
if err != nil {
return nil, err
}
// 检查API返回状态
status, ok := result["status"].(string)
if !ok || status != "success" {
return nil, fmt.Errorf("API返回错误状态: %v", result)
}
// 确保国家和城市字段存在
if _, ok := result["country"]; !ok {
result["country"] = "未知"
}
if _, ok := result["city"]; !ok {
result["city"] = "未知"
}
return result, nil
}
// loadStatsData 从文件加载统计数据
func (s *Server) loadStatsData() {
if s.config.StatsFile == "" {
return
}
// 检查文件是否存在
data, err := ioutil.ReadFile(s.config.StatsFile)
if err != nil {
if !os.IsNotExist(err) {
logger.Error("读取统计数据文件失败", "error", err)
}
return
}
var statsData StatsData
err = json.Unmarshal(data, &statsData)
if err != nil {
logger.Error("解析统计数据失败", "error", err)
return
}
// 恢复统计数据
s.statsMutex.Lock()
if statsData.Stats != nil {
s.stats = statsData.Stats
}
s.statsMutex.Unlock()
s.blockedDomainsMutex.Lock()
if statsData.BlockedDomains != nil {
s.blockedDomains = statsData.BlockedDomains
}
s.blockedDomainsMutex.Unlock()
s.resolvedDomainsMutex.Lock()
if statsData.ResolvedDomains != nil {
s.resolvedDomains = statsData.ResolvedDomains
}
s.resolvedDomainsMutex.Unlock()
s.hourlyStatsMutex.Lock()
if statsData.HourlyStats != nil {
s.hourlyStats = statsData.HourlyStats
}
s.hourlyStatsMutex.Unlock()
s.dailyStatsMutex.Lock()
if statsData.DailyStats != nil {
s.dailyStats = statsData.DailyStats
}
s.dailyStatsMutex.Unlock()
s.monthlyStatsMutex.Lock()
if statsData.MonthlyStats != nil {
s.monthlyStats = statsData.MonthlyStats
}
s.monthlyStatsMutex.Unlock()
// 加载客户端统计数据
s.clientStatsMutex.Lock()
if statsData.ClientStats != nil {
s.clientStats = statsData.ClientStats
}
s.clientStatsMutex.Unlock()
logger.Info("统计数据加载成功")
// 加载查询日志
s.loadQueryLogs()
}
// loadQueryLogs 从文件加载查询日志
func (s *Server) loadQueryLogs() {
if s.config.StatsFile == "" {
return
}
// 获取绝对路径
statsFilePath, err := filepath.Abs(s.config.StatsFile)
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", s.config.StatsFile, "error", err)
return
}
// 构建查询日志文件路径
queryLogPath := filepath.Join(filepath.Dir(statsFilePath), "querylog.json")
// 检查文件是否存在
if _, err := os.Stat(queryLogPath); os.IsNotExist(err) {
logger.Info("查询日志文件不存在,将使用空列表", "file", queryLogPath)
return
}
// 读取文件内容
data, err := ioutil.ReadFile(queryLogPath)
if err != nil {
logger.Error("读取查询日志文件失败", "error", err)
return
}
// 解析数据
var logs []QueryLog
err = json.Unmarshal(data, &logs)
if err != nil {
logger.Error("解析查询日志失败", "error", err)
return
}
// 更新查询日志
s.queryLogsMutex.Lock()
s.queryLogs = logs
// 确保日志数量不超过限制
if len(s.queryLogs) > s.maxQueryLogs {
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
}
s.queryLogsMutex.Unlock()
logger.Info("查询日志加载成功", "count", len(logs))
}
// saveStatsData 保存统计数据到文件
func (s *Server) saveStatsData() {
if s.config.StatsFile == "" {
return
}
// 获取绝对路径以避免工作目录问题
statsFilePath, err := filepath.Abs(s.config.StatsFile)
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", s.config.StatsFile, "error", err)
return
}
// 创建数据目录
statsDir := filepath.Dir(statsFilePath)
err = os.MkdirAll(statsDir, 0755)
if err != nil {
logger.Error("创建统计数据目录失败", "dir", statsDir, "error", err)
return
}
// 收集所有统计数据
statsData := &StatsData{
Stats: s.GetStats(),
LastSaved: time.Now(),
}
// 复制域名数据
s.blockedDomainsMutex.RLock()
statsData.BlockedDomains = make(map[string]*BlockedDomain)
for k, v := range s.blockedDomains {
statsData.BlockedDomains[k] = v
}
s.blockedDomainsMutex.RUnlock()
s.resolvedDomainsMutex.RLock()
statsData.ResolvedDomains = make(map[string]*BlockedDomain)
for k, v := range s.resolvedDomains {
statsData.ResolvedDomains[k] = v
}
s.resolvedDomainsMutex.RUnlock()
s.hourlyStatsMutex.RLock()
statsData.HourlyStats = make(map[string]int64)
for k, v := range s.hourlyStats {
statsData.HourlyStats[k] = v
}
s.hourlyStatsMutex.RUnlock()
s.dailyStatsMutex.RLock()
statsData.DailyStats = make(map[string]int64)
for k, v := range s.dailyStats {
statsData.DailyStats[k] = v
}
s.dailyStatsMutex.RUnlock()
s.monthlyStatsMutex.RLock()
statsData.MonthlyStats = make(map[string]int64)
for k, v := range s.monthlyStats {
statsData.MonthlyStats[k] = v
}
s.monthlyStatsMutex.RUnlock()
// 复制客户端统计数据
s.clientStatsMutex.RLock()
statsData.ClientStats = make(map[string]*ClientStats)
for k, v := range s.clientStats {
statsData.ClientStats[k] = v
}
s.clientStatsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(statsData, "", " ")
if err != nil {
logger.Error("序列化统计数据失败", "error", err)
return
}
// 写入文件
err = os.WriteFile(statsFilePath, jsonData, 0644)
if err != nil {
logger.Error("保存统计数据到文件失败", "file", statsFilePath, "error", err)
return
}
logger.Info("统计数据保存成功", "file", statsFilePath)
// 保存查询日志到文件
s.saveQueryLogs(statsDir)
}
// saveQueryLogs 保存查询日志到文件
func (s *Server) saveQueryLogs(dataDir string) {
// 构建查询日志文件路径
queryLogPath := filepath.Join(dataDir, "querylog.json")
// 获取查询日志数据
s.queryLogsMutex.RLock()
logsCopy := make([]QueryLog, len(s.queryLogs))
copy(logsCopy, s.queryLogs)
s.queryLogsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(logsCopy, "", " ")
if err != nil {
logger.Error("序列化查询日志失败", "error", err)
return
}
// 写入文件
err = os.WriteFile(queryLogPath, jsonData, 0644)
if err != nil {
logger.Error("保存查询日志到文件失败", "file", queryLogPath, "error", err)
return
}
logger.Info("查询日志保存成功", "file", queryLogPath)
}
// startCpuUsageMonitor 启动CPU使用率监控
func (s *Server) startCpuUsageMonitor() {
ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率
defer ticker.Stop()
// 初始化
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
// 存储上一次的CPU时间统计
var prevIdle, prevTotal uint64
for {
select {
case <-ticker.C:
// 获取真实的系统级CPU使用率
cpuUsage, err := getSystemCpuUsage(&prevIdle, &prevTotal)
if err != nil {
// 如果获取失败,使用默认值
cpuUsage = 0.0
logger.Error("获取系统CPU使用率失败", "error", err)
}
s.updateStats(func(stats *Stats) {
stats.CpuUsage = cpuUsage
})
case <-s.ctx.Done():
return
}
}
}
// getSystemCpuUsage 获取系统CPU使用率
func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
// 读取/proc/stat文件获取CPU统计信息
file, err := os.Open("/proc/stat")
if err != nil {
return 0, err
}
defer file.Close()
var cpuUser, cpuNice, cpuSystem, cpuIdle, cpuIowait, cpuIrq, cpuSoftirq, cpuSteal uint64
_, err = fmt.Fscanf(file, "cpu %d %d %d %d %d %d %d %d",
&cpuUser, &cpuNice, &cpuSystem, &cpuIdle, &cpuIowait, &cpuIrq, &cpuSoftirq, &cpuSteal)
if err != nil {
return 0, err
}
// 计算总的CPU时间
total := cpuUser + cpuNice + cpuSystem + cpuIdle + cpuIowait + cpuIrq + cpuSoftirq + cpuSteal
idle := cpuIdle + cpuIowait
// 第一次调用时,只初始化值,不计算使用率
if *prevTotal == 0 || *prevIdle == 0 {
*prevIdle = idle
*prevTotal = total
return 0, nil
}
// 计算CPU使用率
idleDelta := idle - *prevIdle
totalDelta := total - *prevTotal
utilization := float64(totalDelta-idleDelta) / float64(totalDelta) * 100
// 更新上一次的值
*prevIdle = idle
*prevTotal = total
return utilization, nil
}
// startAutoSave 启动自动保存功能
func (s *Server) startAutoSave() {
if s.config.StatsFile == "" || s.config.SaveInterval <= 0 {
return
}
// 设置定时器
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
defer s.saveTicker.Stop()
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", s.config.StatsFile)
// 定期保存数据
for {
select {
case <-s.saveTicker.C:
s.saveStatsData()
case <-s.saveDone:
return
}
}
}