revert
This commit is contained in:
578
dns/server.go
578
dns/server.go
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -37,6 +38,27 @@ type ClientStats struct {
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
// IPGeolocation IP地理位置信息
|
||||
type IPGeolocation struct {
|
||||
Country string `json:"country"` // 国家
|
||||
City string `json:"city"` // 城市
|
||||
Expiry time.Time `json:"expiry"` // 缓存过期时间
|
||||
}
|
||||
|
||||
// QueryLog 查询日志记录
|
||||
type QueryLog struct {
|
||||
Timestamp time.Time // 查询时间
|
||||
ClientIP string // 客户端IP
|
||||
Location string // IP地理位置(国家 城市)
|
||||
Domain string // 查询域名
|
||||
QueryType string // 查询类型
|
||||
ResponseTime int64 // 响应时间(ms)
|
||||
Result string // 查询结果(allowed, blocked, error)
|
||||
BlockRule string // 屏蔽规则(如果被屏蔽)
|
||||
BlockType string // 屏蔽类型(如果被屏蔽)
|
||||
FromCache bool // 是否来自缓存
|
||||
}
|
||||
|
||||
// StatsData 用于持久化的统计数据结构
|
||||
type StatsData struct {
|
||||
Stats *Stats `json:"stats"`
|
||||
@@ -73,11 +95,22 @@ type Server struct {
|
||||
dailyStats map[string]int64 // 按天统计屏蔽数量
|
||||
monthlyStatsMutex sync.RWMutex
|
||||
monthlyStats map[string]int64 // 按月统计屏蔽数量
|
||||
saveTicker *time.Ticker // 用于定时保存数据
|
||||
startTime time.Time // 服务器启动时间
|
||||
saveDone chan struct{} // 用于通知保存协程停止
|
||||
stopped bool // 服务器是否已经停止
|
||||
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
|
||||
queryLogsMutex sync.RWMutex
|
||||
queryLogs []QueryLog // 查询日志列表
|
||||
maxQueryLogs int // 最大保存日志数量
|
||||
saveTicker *time.Ticker // 用于定时保存数据
|
||||
startTime time.Time // 服务器启动时间
|
||||
saveDone chan struct{} // 用于通知保存协程停止
|
||||
stopped bool // 服务器是否已经停止
|
||||
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
|
||||
|
||||
// IP地理位置缓存
|
||||
ipGeolocationCache map[string]*IPGeolocation // IP地址到地理位置的映射
|
||||
ipGeolocationCacheMutex sync.RWMutex // 保护IP地理位置缓存的互斥锁
|
||||
ipGeolocationCacheTTL time.Duration // 缓存有效期
|
||||
|
||||
// DNS查询缓存
|
||||
dnsCache *DNSCache // DNS响应缓存
|
||||
}
|
||||
|
||||
// Stats DNS服务器统计信息
|
||||
@@ -97,6 +130,10 @@ type Stats struct {
|
||||
// NewServer 创建DNS服务器实例
|
||||
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// 从配置中读取DNS缓存TTL值(分钟)
|
||||
cacheTTL := time.Duration(config.CacheTTL) * time.Minute
|
||||
|
||||
server := &Server{
|
||||
config: config,
|
||||
shieldConfig: shieldConfig,
|
||||
@@ -125,8 +162,15 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
|
||||
hourlyStats: make(map[string]int64),
|
||||
dailyStats: make(map[string]int64),
|
||||
monthlyStats: make(map[string]int64),
|
||||
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片,容量1000
|
||||
maxQueryLogs: 10000, // 最大保存10000条日志
|
||||
saveDone: make(chan struct{}),
|
||||
stopped: false, // 初始化为未停止状态
|
||||
// IP地理位置缓存初始化
|
||||
ipGeolocationCache: make(map[string]*IPGeolocation),
|
||||
ipGeolocationCacheTTL: 24 * time.Hour, // 缓存有效期24小时
|
||||
// DNS查询缓存初始化
|
||||
dnsCache: NewDNSCache(cacheTTL),
|
||||
}
|
||||
|
||||
// 加载已保存的统计数据
|
||||
@@ -232,8 +276,16 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
// 获取来源IP
|
||||
sourceIP := w.RemoteAddr().String()
|
||||
// 提取IP地址部分,去掉端口
|
||||
if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 {
|
||||
sourceIP = sourceIP[:idx]
|
||||
if strings.HasPrefix(sourceIP, "[") {
|
||||
// IPv6地址格式: [::1]:53
|
||||
if idx := strings.Index(sourceIP, "]"); idx >= 0 {
|
||||
sourceIP = sourceIP[1:idx] // 去掉方括号
|
||||
}
|
||||
} else {
|
||||
// IPv4地址格式: 127.0.0.1:53
|
||||
if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 {
|
||||
sourceIP = sourceIP[:idx]
|
||||
}
|
||||
}
|
||||
|
||||
// 更新来源IP统计
|
||||
@@ -246,28 +298,10 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
// 更新客户端统计
|
||||
s.updateClientStats(sourceIP)
|
||||
|
||||
// 只处理递归查询
|
||||
if r.RecursionDesired == false {
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(r)
|
||||
response.RecursionAvailable = true
|
||||
response.SetRcode(r, dns.RcodeRefused)
|
||||
w.WriteMsg(response)
|
||||
|
||||
// 计算响应时间
|
||||
responseTime := time.Since(startTime).Milliseconds()
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.TotalResponseTime += responseTime
|
||||
if stats.Queries > 0 {
|
||||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取查询域名和类型
|
||||
var domain string
|
||||
var queryType string
|
||||
var qType uint16
|
||||
if len(r.Question) > 0 {
|
||||
domain = r.Question[0].Name
|
||||
// 移除末尾的点
|
||||
@@ -276,6 +310,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
// 获取查询类型
|
||||
queryType = dns.TypeToString[r.Question[0].Qtype]
|
||||
qType = r.Question[0].Qtype
|
||||
// 更新查询类型统计
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.QueryTypes[queryType]++
|
||||
@@ -284,22 +319,52 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr())
|
||||
|
||||
// 检查hosts文件是否有匹配
|
||||
if ip, exists := s.shieldManager.GetHostsIP(domain); exists {
|
||||
s.handleHostsResponse(w, r, ip)
|
||||
// 计算响应时间
|
||||
responseTime := time.Since(startTime).Milliseconds()
|
||||
// 只处理递归查询
|
||||
if r.RecursionDesired == false {
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(r)
|
||||
response.RecursionAvailable = true
|
||||
response.SetRcode(r, dns.RcodeRefused)
|
||||
w.WriteMsg(response)
|
||||
|
||||
// 缓存命中,响应时间设为0ms
|
||||
responseTime := int64(0)
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.TotalResponseTime += responseTime
|
||||
if stats.Queries > 0 {
|
||||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||||
}
|
||||
})
|
||||
|
||||
// 添加查询日志
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查hosts文件是否有匹配
|
||||
if ip, exists := s.shieldManager.GetHostsIP(domain); exists {
|
||||
s.handleHostsResponse(w, r, ip)
|
||||
// 缓存命中,响应时间设为0ms
|
||||
responseTime := int64(0)
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.TotalResponseTime += responseTime
|
||||
if stats.Queries > 0 {
|
||||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||||
}
|
||||
})
|
||||
|
||||
// 添加查询日志
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否被屏蔽
|
||||
if s.shieldManager.IsBlocked(domain) {
|
||||
// 获取屏蔽详情
|
||||
blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
|
||||
blockRule, _ := blockDetails["blockRule"].(string)
|
||||
blockType, _ := blockDetails["blockRuleType"].(string)
|
||||
|
||||
s.handleBlockedResponse(w, r, domain)
|
||||
// 计算响应时间
|
||||
responseTime := time.Since(startTime).Milliseconds()
|
||||
@@ -309,19 +374,68 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||||
}
|
||||
})
|
||||
|
||||
// 添加查询日志
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发到上游DNS服务器
|
||||
s.forwardDNSRequest(w, r, domain)
|
||||
// 计算响应时间
|
||||
responseTime := time.Since(startTime).Milliseconds()
|
||||
// 检查缓存中是否有响应(增强版缓存查询)
|
||||
if cachedResponse, found := s.dnsCache.Get(r.Question[0].Name, qType); found {
|
||||
// 缓存命中,直接返回缓存的响应
|
||||
cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题
|
||||
cachedResponseCopy.Id = r.Id // 更新ID以匹配请求
|
||||
cachedResponseCopy.Compress = true
|
||||
w.WriteMsg(cachedResponseCopy)
|
||||
|
||||
// 计算响应时间
|
||||
responseTime := time.Since(startTime).Milliseconds()
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.TotalResponseTime += responseTime
|
||||
if stats.Queries > 0 {
|
||||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||||
}
|
||||
})
|
||||
|
||||
// 添加查询日志 - 标记为缓存
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true)
|
||||
logger.Debug("从缓存返回DNS响应", "domain", domain, "type", queryType)
|
||||
return
|
||||
}
|
||||
|
||||
// 缓存未命中,转发到上游DNS服务器
|
||||
response, rtt := s.forwardDNSRequestWithCache(r, domain)
|
||||
if response != nil {
|
||||
// 写入响应给客户端
|
||||
w.WriteMsg(response)
|
||||
}
|
||||
|
||||
// 使用上游服务器的实际响应时间(转换为毫秒)
|
||||
responseTime := int64(rtt.Milliseconds())
|
||||
// 如果rtt为0(查询失败),则使用本地计算的时间
|
||||
if responseTime == 0 {
|
||||
responseTime = time.Since(startTime).Milliseconds()
|
||||
}
|
||||
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.TotalResponseTime += responseTime
|
||||
if stats.Queries > 0 {
|
||||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||||
}
|
||||
})
|
||||
|
||||
// 如果响应成功,缓存结果(增强版缓存存储)
|
||||
if response != nil && response.Rcode == dns.RcodeSuccess {
|
||||
// 创建响应副本以避免后续修改影响缓存
|
||||
responseCopy := response.Copy()
|
||||
// 设置合理的TTL,不超过默认的30分钟
|
||||
defaultCacheTTL := 30 * time.Minute
|
||||
s.dnsCache.Set(r.Question[0].Name, qType, responseCopy, defaultCacheTTL)
|
||||
logger.Debug("DNS响应已缓存", "domain", domain, "type", queryType, "ttl", defaultCacheTTL)
|
||||
}
|
||||
|
||||
// 添加查询日志 - 标记为实时
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false)
|
||||
}
|
||||
|
||||
// handleHostsResponse 处理hosts文件匹配的响应
|
||||
@@ -425,7 +539,8 @@ func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain
|
||||
}
|
||||
|
||||
// forwardDNSRequest 转发DNS请求到上游服务器
|
||||
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||||
// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应
|
||||
func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration) {
|
||||
// 尝试所有上游DNS服务器
|
||||
for _, upstream := range s.config.UpstreamDNS {
|
||||
response, rtt, err := s.resolver.Exchange(r, upstream)
|
||||
@@ -433,7 +548,6 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
|
||||
// 设置递归可用标志
|
||||
response.RecursionAvailable = true
|
||||
|
||||
w.WriteMsg(response)
|
||||
logger.Debug("DNS查询成功", "domain", domain, "rtt", rtt, "server", upstream)
|
||||
|
||||
// 记录解析域名统计
|
||||
@@ -442,7 +556,7 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.Allowed++
|
||||
})
|
||||
return
|
||||
return response, rtt
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,12 +565,18 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
|
||||
response.SetReply(r)
|
||||
response.RecursionAvailable = true
|
||||
response.SetRcode(r, dns.RcodeServerFailure)
|
||||
w.WriteMsg(response)
|
||||
|
||||
logger.Error("DNS查询失败", "domain", domain)
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.Errors++
|
||||
})
|
||||
return response, 0
|
||||
}
|
||||
|
||||
// forwardDNSRequest 转发DNS请求到上游服务器
|
||||
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||||
response, _ := s.forwardDNSRequestWithCache(r, domain)
|
||||
w.WriteMsg(response)
|
||||
}
|
||||
|
||||
// updateBlockedDomainStats 更新被屏蔽域名统计
|
||||
@@ -539,6 +659,38 @@ func (s *Server) updateStats(update func(*Stats)) {
|
||||
update(s.stats)
|
||||
}
|
||||
|
||||
// addQueryLog 添加查询日志
|
||||
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache bool) {
|
||||
// 获取IP地理位置
|
||||
location := s.getIpGeolocation(clientIP)
|
||||
|
||||
// 创建日志记录
|
||||
log := QueryLog{
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
Location: location,
|
||||
Domain: domain,
|
||||
QueryType: queryType,
|
||||
ResponseTime: responseTime,
|
||||
Result: result,
|
||||
BlockRule: blockRule,
|
||||
BlockType: blockType,
|
||||
FromCache: fromCache,
|
||||
}
|
||||
|
||||
// 添加到日志列表
|
||||
s.queryLogsMutex.Lock()
|
||||
defer s.queryLogsMutex.Unlock()
|
||||
|
||||
// 插入到列表开头
|
||||
s.queryLogs = append([]QueryLog{log}, s.queryLogs...)
|
||||
|
||||
// 限制日志数量
|
||||
if len(s.queryLogs) > s.maxQueryLogs {
|
||||
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
|
||||
}
|
||||
}
|
||||
|
||||
// GetStartTime 获取服务器启动时间
|
||||
func (s *Server) GetStartTime() time.Time {
|
||||
return s.startTime
|
||||
@@ -576,6 +728,145 @@ func (s *Server) GetStats() *Stats {
|
||||
}
|
||||
}
|
||||
|
||||
// GetQueryLogs 获取查询日志
|
||||
func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm string) []QueryLog {
|
||||
s.queryLogsMutex.RLock()
|
||||
defer s.queryLogsMutex.RUnlock()
|
||||
|
||||
// 确保偏移量和限制值合理
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 100 // 默认返回100条日志
|
||||
}
|
||||
|
||||
// 创建日志副本用于过滤和排序
|
||||
var logsCopy []QueryLog
|
||||
|
||||
// 先过滤日志
|
||||
for _, log := range s.queryLogs {
|
||||
// 应用结果过滤
|
||||
if resultFilter != "" && log.Result != resultFilter {
|
||||
continue
|
||||
}
|
||||
|
||||
// 应用搜索过滤
|
||||
if searchTerm != "" {
|
||||
// 搜索域名或客户端IP
|
||||
if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
logsCopy = append(logsCopy, log)
|
||||
}
|
||||
|
||||
// 排序日志
|
||||
if sortField != "" {
|
||||
sort.Slice(logsCopy, func(i, j int) bool {
|
||||
var a, b interface{}
|
||||
switch sortField {
|
||||
case "time":
|
||||
a = logsCopy[i].Timestamp
|
||||
b = logsCopy[j].Timestamp
|
||||
case "clientIp":
|
||||
a = logsCopy[i].ClientIP
|
||||
b = logsCopy[j].ClientIP
|
||||
case "domain":
|
||||
a = logsCopy[i].Domain
|
||||
b = logsCopy[j].Domain
|
||||
case "responseTime":
|
||||
a = logsCopy[i].ResponseTime
|
||||
b = logsCopy[j].ResponseTime
|
||||
case "blockRule":
|
||||
a = logsCopy[i].BlockRule
|
||||
b = logsCopy[j].BlockRule
|
||||
default:
|
||||
// 默认按时间排序
|
||||
a = logsCopy[i].Timestamp
|
||||
b = logsCopy[j].Timestamp
|
||||
}
|
||||
|
||||
// 根据排序方向比较
|
||||
if sortDirection == "asc" {
|
||||
return compareValues(a, b) < 0
|
||||
}
|
||||
return compareValues(a, b) > 0
|
||||
})
|
||||
}
|
||||
|
||||
// 计算返回范围
|
||||
start := offset
|
||||
end := offset + limit
|
||||
if end > len(logsCopy) {
|
||||
end = len(logsCopy)
|
||||
}
|
||||
if start >= len(logsCopy) {
|
||||
return []QueryLog{} // 没有数据,返回空切片
|
||||
}
|
||||
|
||||
return logsCopy[start:end]
|
||||
}
|
||||
|
||||
// compareValues 比较两个值
|
||||
func compareValues(a, b interface{}) int {
|
||||
switch v1 := a.(type) {
|
||||
case time.Time:
|
||||
v2 := b.(time.Time)
|
||||
if v1.Before(v2) {
|
||||
return -1
|
||||
}
|
||||
if v1.After(v2) {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
case string:
|
||||
v2 := b.(string)
|
||||
if v1 < v2 {
|
||||
return -1
|
||||
}
|
||||
if v1 > v2 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
case int64:
|
||||
v2 := b.(int64)
|
||||
if v1 < v2 {
|
||||
return -1
|
||||
}
|
||||
if v1 > v2 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// GetQueryLogsCount 获取查询日志总数
|
||||
func (s *Server) GetQueryLogsCount() int {
|
||||
s.queryLogsMutex.RLock()
|
||||
defer s.queryLogsMutex.RUnlock()
|
||||
return len(s.queryLogs)
|
||||
}
|
||||
|
||||
// GetQueryStats 获取查询统计信息
|
||||
func (s *Server) GetQueryStats() map[string]interface{} {
|
||||
s.statsMutex.Lock()
|
||||
defer s.statsMutex.Unlock()
|
||||
|
||||
// 计算统计数据
|
||||
return map[string]interface{}{
|
||||
"totalQueries": s.stats.Queries,
|
||||
"blockedQueries": s.stats.Blocked,
|
||||
"allowedQueries": s.stats.Allowed,
|
||||
"errorQueries": s.stats.Errors,
|
||||
"avgResponseTime": s.stats.AvgResponseTime,
|
||||
"activeIPs": len(s.stats.SourceIPs),
|
||||
}
|
||||
}
|
||||
|
||||
// GetTopBlockedDomains 获取TOP屏蔽域名列表
|
||||
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
|
||||
s.blockedDomainsMutex.RLock()
|
||||
@@ -707,6 +998,132 @@ func (s *Server) GetMonthlyStats() map[string]int64 {
|
||||
return result
|
||||
}
|
||||
|
||||
// isPrivateIP 检测IP地址是否为内网IP
|
||||
func isPrivateIP(ip string) bool {
|
||||
// 解析IP地址
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查IPv4内网地址
|
||||
if ipv4 := parsedIP.To4(); ipv4 != nil {
|
||||
// 10.0.0.0/8
|
||||
if ipv4[0] == 10 {
|
||||
return true
|
||||
}
|
||||
// 172.16.0.0/12
|
||||
if ipv4[0] == 172 && (ipv4[1] >= 16 && ipv4[1] <= 31) {
|
||||
return true
|
||||
}
|
||||
// 192.168.0.0/16
|
||||
if ipv4[0] == 192 && ipv4[1] == 168 {
|
||||
return true
|
||||
}
|
||||
// 127.0.0.0/8 (localhost)
|
||||
if ipv4[0] == 127 {
|
||||
return true
|
||||
}
|
||||
// 169.254.0.0/16 (链路本地地址)
|
||||
if ipv4[0] == 169 && ipv4[1] == 254 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查IPv6内网地址
|
||||
// ::1/128 (localhost)
|
||||
if parsedIP.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
// fc00::/7 (唯一本地地址)
|
||||
if parsedIP[0]&0xfc == 0xfc {
|
||||
return true
|
||||
}
|
||||
// fe80::/10 (链路本地地址)
|
||||
if parsedIP[0]&0xfe == 0xfe && parsedIP[1]&0xc0 == 0x80 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getIpGeolocation 获取IP地址的地理位置信息
|
||||
func (s *Server) getIpGeolocation(ip string) string {
|
||||
// 检查IP是否为本地或内网地址
|
||||
if isPrivateIP(ip) {
|
||||
return "内网 内网"
|
||||
}
|
||||
|
||||
// 先检查缓存
|
||||
s.ipGeolocationCacheMutex.RLock()
|
||||
geo, exists := s.ipGeolocationCache[ip]
|
||||
s.ipGeolocationCacheMutex.RUnlock()
|
||||
|
||||
// 如果缓存存在且未过期,直接返回
|
||||
if exists && time.Now().Before(geo.Expiry) {
|
||||
return fmt.Sprintf("%s %s", geo.Country, geo.City)
|
||||
}
|
||||
|
||||
// 缓存不存在或已过期,从API获取
|
||||
geoInfo, err := s.fetchIpGeolocationFromAPI(ip)
|
||||
if err != nil {
|
||||
logger.Error("获取IP地理位置失败", "ip", ip, "error", err)
|
||||
return "未知 未知"
|
||||
}
|
||||
|
||||
// 保存到缓存
|
||||
s.ipGeolocationCacheMutex.Lock()
|
||||
s.ipGeolocationCache[ip] = &IPGeolocation{
|
||||
Country: geoInfo["country"].(string),
|
||||
City: geoInfo["city"].(string),
|
||||
Expiry: time.Now().Add(s.ipGeolocationCacheTTL),
|
||||
}
|
||||
s.ipGeolocationCacheMutex.Unlock()
|
||||
|
||||
// 返回格式化的地理位置
|
||||
return fmt.Sprintf("%s %s", geoInfo["country"].(string), geoInfo["city"].(string))
|
||||
}
|
||||
|
||||
// fetchIpGeolocationFromAPI 从第三方API获取IP地理位置信息
|
||||
func (s *Server) fetchIpGeolocationFromAPI(ip string) (map[string]interface{}, error) {
|
||||
// 使用ip-api.com获取IP地理位置信息
|
||||
url := fmt.Sprintf("http://ip-api.com/json/%s?fields=country,city", ip)
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应内容
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析JSON响应
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查API返回状态
|
||||
status, ok := result["status"].(string)
|
||||
if !ok || status != "success" {
|
||||
return nil, fmt.Errorf("API返回错误状态: %v", result)
|
||||
}
|
||||
|
||||
// 确保国家和城市字段存在
|
||||
if _, ok := result["country"]; !ok {
|
||||
result["country"] = "未知"
|
||||
}
|
||||
if _, ok := result["city"]; !ok {
|
||||
result["city"] = "未知"
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// loadStatsData 从文件加载统计数据
|
||||
func (s *Server) loadStatsData() {
|
||||
if s.config.StatsFile == "" {
|
||||
@@ -774,6 +1191,58 @@ func (s *Server) loadStatsData() {
|
||||
s.clientStatsMutex.Unlock()
|
||||
|
||||
logger.Info("统计数据加载成功")
|
||||
|
||||
// 加载查询日志
|
||||
s.loadQueryLogs()
|
||||
}
|
||||
|
||||
// loadQueryLogs 从文件加载查询日志
|
||||
func (s *Server) loadQueryLogs() {
|
||||
if s.config.StatsFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取绝对路径
|
||||
statsFilePath, err := filepath.Abs(s.config.StatsFile)
|
||||
if err != nil {
|
||||
logger.Error("获取统计文件绝对路径失败", "path", s.config.StatsFile, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建查询日志文件路径
|
||||
queryLogPath := filepath.Join(filepath.Dir(statsFilePath), "querylog.json")
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(queryLogPath); os.IsNotExist(err) {
|
||||
logger.Info("查询日志文件不存在,将使用空列表", "file", queryLogPath)
|
||||
return
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
data, err := ioutil.ReadFile(queryLogPath)
|
||||
if err != nil {
|
||||
logger.Error("读取查询日志文件失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析数据
|
||||
var logs []QueryLog
|
||||
err = json.Unmarshal(data, &logs)
|
||||
if err != nil {
|
||||
logger.Error("解析查询日志失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新查询日志
|
||||
s.queryLogsMutex.Lock()
|
||||
s.queryLogs = logs
|
||||
// 确保日志数量不超过限制
|
||||
if len(s.queryLogs) > s.maxQueryLogs {
|
||||
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
|
||||
}
|
||||
s.queryLogsMutex.Unlock()
|
||||
|
||||
logger.Info("查询日志加载成功", "count", len(logs))
|
||||
}
|
||||
|
||||
// saveStatsData 保存统计数据到文件
|
||||
@@ -862,6 +1331,37 @@ func (s *Server) saveStatsData() {
|
||||
}
|
||||
|
||||
logger.Info("统计数据保存成功", "file", statsFilePath)
|
||||
|
||||
// 保存查询日志到文件
|
||||
s.saveQueryLogs(statsDir)
|
||||
}
|
||||
|
||||
// saveQueryLogs 保存查询日志到文件
|
||||
func (s *Server) saveQueryLogs(dataDir string) {
|
||||
// 构建查询日志文件路径
|
||||
queryLogPath := filepath.Join(dataDir, "querylog.json")
|
||||
|
||||
// 获取查询日志数据
|
||||
s.queryLogsMutex.RLock()
|
||||
logsCopy := make([]QueryLog, len(s.queryLogs))
|
||||
copy(logsCopy, s.queryLogs)
|
||||
s.queryLogsMutex.RUnlock()
|
||||
|
||||
// 序列化数据
|
||||
jsonData, err := json.MarshalIndent(logsCopy, "", " ")
|
||||
if err != nil {
|
||||
logger.Error("序列化查询日志失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
err = os.WriteFile(queryLogPath, jsonData, 0644)
|
||||
if err != nil {
|
||||
logger.Error("保存查询日志到文件失败", "file", queryLogPath, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("查询日志保存成功", "file", queryLogPath)
|
||||
}
|
||||
|
||||
// startCpuUsageMonitor 启动CPU使用率监控
|
||||
|
||||
Reference in New Issue
Block a user