多项更新优化

This commit is contained in:
Alex Yang
2025-12-26 09:02:59 +08:00
parent 356310ae75
commit b48dc4ed27
18 changed files with 1178 additions and 348 deletions

View File

@@ -22,6 +22,17 @@ import (
"github.com/miekg/dns"
)
// 确保DNS服务器地址包含端口号默认添加53端口
func normalizeDNSServerAddress(address string) string {
// 检查地址是否已经包含端口号
if _, _, err := net.SplitHostPort(address); err != nil {
// 如果没有端口号添加默认的53端口
return net.JoinHostPort(address, "53")
}
// 已经有端口号,直接返回
return address
}
// BlockedDomain 屏蔽域名统计
type BlockedDomain struct {
Domain string
@@ -45,22 +56,31 @@ type IPGeolocation struct {
Expiry time.Time `json:"expiry"` // 缓存过期时间
}
// DNSAnswer DNS解析记录
type DNSAnswer struct {
Type string `json:"type"` // 记录类型
Value string `json:"value"` // 记录值
TTL uint32 `json:"ttl"` // 生存时间
}
// QueryLog 查询日志记录
type QueryLog struct {
Timestamp time.Time // 查询时间
ClientIP string // 客户端IP
Location string // IP地理位置国家 城市)
Domain string // 查询域名
QueryType string // 查询类型
ResponseTime int64 // 响应时间(ms)
Result string // 查询结果allowed, blocked, error
BlockRule string // 屏蔽规则(如果被屏蔽)
BlockType string // 屏蔽类型(如果被屏蔽)
FromCache bool // 是否来自缓存
DNSSEC bool // 是否使用了DNSSEC
EDNS bool // 是否使用了EDNS
DNSServer string // 使用的DNS服务器
DNSSECServer string // 使用的DNSSEC专用服务器
Timestamp time.Time `json:"timestamp"` // 查询时间
ClientIP string `json:"clientIP"` // 客户端IP
Location string `json:"location"` // IP地理位置国家 城市)
Domain string `json:"domain"` // 查询域名
QueryType string `json:"queryType"` // 查询类型
ResponseTime int64 `json:"responseTime"` // 响应时间(ms)
Result string `json:"result"` // 查询结果allowed, blocked, error
BlockRule string `json:"blockRule"` // 屏蔽规则(如果被屏蔽)
BlockType string `json:"blockType"` // 屏蔽类型(如果被屏蔽)
FromCache bool `json:"fromCache"` // 是否来自缓存
DNSSEC bool `json:"dnssec"` // 是否使用了DNSSEC
EDNS bool `json:"edns"` // 是否使用了EDNS
DNSServer string `json:"dnsServer"` // 使用的DNS服务器
DNSSECServer string `json:"dnssecServer"` // 使用的DNSSEC专用服务器
Answers []DNSAnswer `json:"answers"` // 解析记录
ResponseCode int `json:"responseCode"` // DNS响应代码
}
// StatsData 用于持久化的统计数据结构
@@ -348,6 +368,29 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
s.updateStats(func(stats *Stats) {
stats.QueryTypes[queryType]++
})
// 检查是否是AAAA记录查询且IPv6解析已禁用
if qType == dns.TypeAAAA && !s.config.EnableIPv6 {
// 返回NXDOMAIN响应域名不存在
response := new(dns.Msg)
response.SetReply(r)
response.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(response)
// 更新统计信息
responseTime := int64(0)
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeNameError)
logger.Debug("IPv6解析已禁用拒绝AAAA记录查询", "domain", domain)
return
}
}
logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr())
@@ -370,7 +413,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "")
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeRefused)
return
}
@@ -386,8 +429,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false, false, true, "缓存", "无")
// 该方法内部未直接调用addQueryLog而是在handleDNSRequest中处理
return
}
@@ -408,8 +450,16 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false, false, true, "无", "无")
// 添加查询日志 - 被屏蔽域名
blockedAnswers := []DNSAnswer{}
// 根据屏蔽方法确定响应代码
blockedRcode := dns.RcodeNameError // 默认NXDOMAIN
if blockMethod := s.shieldConfig.BlockMethod; blockMethod == "refused" {
blockedRcode = dns.RcodeRefused
} else if blockMethod == "emptyIP" || blockMethod == "customIP" {
blockedRcode = dns.RcodeSuccess
}
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false, false, true, "无", "无", blockedAnswers, blockedRcode)
return
}
@@ -481,8 +531,25 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
})
}
// 从缓存响应中提取解析记录
cachedAnswers := []DNSAnswer{}
if cachedResponse != nil {
for _, rr := range cachedResponse.Answer {
cachedAnswers = append(cachedAnswers, DNSAnswer{
Type: dns.TypeToString[rr.Header().Rrtype],
Value: rr.String(),
TTL: rr.Header().Ttl,
})
}
}
// 添加查询日志 - 标记为缓存
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true, cachedDNSSEC, true, "缓存", "无")
// 从缓存响应中获取响应代码
cacheRcode := dns.RcodeSuccess // 默认成功
if cachedResponse != nil {
cacheRcode = cachedResponse.Rcode
}
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true, cachedDNSSEC, true, "缓存", "无", cachedAnswers, cacheRcode)
logger.Debug("从缓存返回DNS响应", "domain", domain, "type", queryType, "dnssec", cachedDNSSEC)
return
}
@@ -566,8 +633,25 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
logger.Debug("DNS响应已缓存", "domain", domain, "type", queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC)
}
// 从响应中提取解析记录
responseAnswers := []DNSAnswer{}
if response != nil {
for _, rr := range response.Answer {
responseAnswers = append(responseAnswers, DNSAnswer{
Type: dns.TypeToString[rr.Header().Rrtype],
Value: rr.String(),
TTL: rr.Header().Ttl,
})
}
}
// 添加查询日志 - 标记为实时
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false, responseDNSSEC, true, dnsServer, dnssecServer)
// 从响应中获取响应代码
realRcode := dns.RcodeSuccess // 默认成功
if response != nil {
realRcode = response.Rcode
}
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false, responseDNSSEC, true, dnsServer, dnssecServer, responseAnswers, realRcode)
}
// handleHostsResponse 处理hosts文件匹配的响应
@@ -731,14 +815,35 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
// 2. 如果没有匹配的域名特定配置
if !domainMatched {
// 如果启用了DNSSEC且有配置DNSSEC专用服务器并且域名不匹配NoDNSSECDomains则使用DNSSEC专用服务器
// 创建一个新的切片来存储最终的上游服务器列表
var finalUpstreamDNS []string
// 首先添加用户配置的上游DNS服务器
finalUpstreamDNS = append(finalUpstreamDNS, s.config.UpstreamDNS...)
logger.Debug("使用用户配置的上游DNS服务器", "servers", finalUpstreamDNS)
// 如果启用了DNSSEC且有配置DNSSEC专用服务器并且域名不匹配NoDNSSECDomains则将DNSSEC专用服务器添加到列表中
if s.config.EnableDNSSEC && len(s.config.DNSSECUpstreamDNS) > 0 && !noDNSSEC {
selectedUpstreamDNS = s.config.DNSSECUpstreamDNS
logger.Debug("使用DNSSEC专用服务器", "servers", selectedUpstreamDNS)
} else {
// 否则使用默认的上游DNS服务器
selectedUpstreamDNS = s.config.UpstreamDNS
// 合并DNSSEC专用服务器到上游服务器列表避免重复并确保包含端口号
for _, dnssecServer := range s.config.DNSSECUpstreamDNS {
hasDuplicate := false
// 确保DNSSEC服务器地址包含端口号
normalizedDnssecServer := normalizeDNSServerAddress(dnssecServer)
for _, upstream := range finalUpstreamDNS {
if upstream == normalizedDnssecServer {
hasDuplicate = true
break
}
}
if !hasDuplicate {
finalUpstreamDNS = append(finalUpstreamDNS, normalizedDnssecServer)
}
}
logger.Debug("合并DNSSEC专用服务器到上游服务器列表", "servers", finalUpstreamDNS)
}
// 使用最终合并后的服务器列表
selectedUpstreamDNS = finalUpstreamDNS
}
// 1. 首先尝试所有配置的上游DNS服务器
@@ -769,8 +874,8 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
go func(server string) {
defer wg.Done()
// 发送请求并获取响应
response, rtt, err := s.resolver.Exchange(r, server)
// 发送请求并获取响应,确保服务器地址包含端口号
response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(server))
select {
case responses <- serverResponse{response, rtt, server, err}:
@@ -825,55 +930,103 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
resp.response.AuthenticatedData = false
}
// 如果响应成功或为NXDOMAIN根据DNSSEC状态选择最佳响应
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
// 检查当前使用的服务器是否是DNSSEC专用服务器
for _, dnssecServer := range dnssecServers {
if dnssecServer == resp.server {
usedDNSSECServer = resp.server
break
}
// 检查当前服务器是否是DNSSEC专用服务器
for _, dnssecServer := range dnssecServers {
if dnssecServer == resp.server {
usedDNSSECServer = resp.server
break
}
}
if resp.response.Rcode == dns.RcodeSuccess {
// 处理成功响应
// 优先选择带有DNSSEC记录的响应
if containsDNSSEC {
// 检查当前服务器是否是用户配置的上游DNS服务器
isUserUpstream := false
for _, userServer := range s.config.UpstreamDNS {
if userServer == resp.server {
isUserUpstream = true
break
}
}
// 处理响应优先选择用户配置的主DNS服务器
if resp.response.Rcode == dns.RcodeSuccess {
// 成功响应,优先使用
if isUserUpstream {
// 用户配置的主DNS服务器响应直接设置为最佳响应
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
hasDNSSECResponse = containsDNSSEC
usedDNSServer = resp.server
logger.Debug("使用用户配置的上游服务器响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
} else if containsDNSSEC {
// 非用户配置服务器但有DNSSEC记录
if !hasBestResponse || !isUserUpstream {
// 如果还没有最佳响应,或者当前最佳响应不是用户配置的服务器,则更新
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
hasDNSSECResponse = true
usedDNSServer = resp.server
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
} else if !hasBestResponse {
// 没有带DNSSEC的响应时保存第一个成功响应
}
} else {
// 非用户配置服务器没有DNSSEC记录
if !hasBestResponse {
// 如果还没有最佳响应,设置为最佳响应
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
usedDNSServer = resp.server
logger.Debug("找到最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
}
} else if resp.response.Rcode == dns.RcodeNameError {
// 处理NXDOMAIN响应
// 如果还没有最佳响应或者最佳响应也是NXDOMAIN优先选择更快的NXDOMAIN响应
if !hasBestResponse || bestResponse.Rcode == dns.RcodeNameError {
// 如果还没有最佳响应,或者当前响应更快,更新最佳响应
if !hasBestResponse || resp.rtt < bestRtt {
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
usedDNSServer = resp.server
logger.Debug("找到NXDOMAIN最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
}
}
} else if resp.response.Rcode == dns.RcodeNameError {
// NXDOMAIN响应
if !hasBestResponse || bestResponse.Rcode == dns.RcodeNameError {
// 如果还没有最佳响应,或者最佳响应也是NXDOMAIN
if isUserUpstream {
// 用户配置的服务器,直接使用
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
usedDNSServer = resp.server
logger.Debug("使用用户配置的上游服务器NXDOMAIN响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
} else if !hasBestResponse || resp.rtt < bestRtt {
// 非用户配置服务器,选择更快的响应
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
usedDNSServer = resp.server
logger.Debug("找到NXDOMAIN最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
}
}
// 保存为备选响应
}
// 更新备选响应,确保总有一个可用的响应
if resp.response != nil {
if !hasBackup {
// 第一次保存备选响应
backupResponse = resp.response
backupRtt = resp.rtt
hasBackup = true
} else {
// 后续响应,优先保存用户配置的服务器响应作为备选
if isUserUpstream {
backupResponse = resp.response
backupRtt = resp.rtt
}
}
}
// 即使响应不是成功或NXDOMAIN也保存为最佳响应如果还没有的话
// 确保总有一个响应返回给客户端
if !hasBestResponse {
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
usedDNSServer = resp.server
logger.Debug("使用非成功响应作为最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt, "rcode", resp.response.Rcode)
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
@@ -882,9 +1035,32 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
case "loadbalance":
// 负载均衡模式 - 使用加权随机选择算法
// 1. 选择一个加权随机服务器
selectedServer := s.selectWeightedRandomServer(selectedUpstreamDNS)
if selectedServer != "" {
// 1. 尝试所有可用的服务器,直到找到一个能正常工作的
var triedServers []string
for len(triedServers) < len(selectedUpstreamDNS) {
// 从剩余的服务器中选择一个加权随机服务器
var availableServers []string
for _, server := range selectedUpstreamDNS {
found := false
for _, tried := range triedServers {
if server == tried {
found = true
break
}
}
if !found {
availableServers = append(availableServers, server)
}
}
selectedServer := s.selectWeightedRandomServer(availableServers)
if selectedServer == "" {
break
}
triedServers = append(triedServers, selectedServer)
logger.Debug("在负载均衡模式下选择服务器", "domain", domain, "server", selectedServer, "triedServers", triedServers)
// 设置超时上下文
timeoutCtx, cancel := context.WithTimeout(s.ctx, time.Duration(s.config.Timeout)*time.Millisecond)
defer cancel()
@@ -897,7 +1073,7 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
}, 1)
go func() {
response, rtt, err := s.resolver.Exchange(r, selectedServer)
response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(selectedServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
@@ -997,10 +1173,12 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
backupRtt = rtt
hasBackup = true
}
break // 找到有效响应,退出循环
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(selectedServer, false, 0)
logger.Debug("服务器请求失败,尝试下一个", "domain", domain, "server", selectedServer, "error", err)
}
}
@@ -1021,7 +1199,7 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
}, 1)
go func() {
resp, r, e := s.resolver.Exchange(r, fastestServer)
resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(fastestServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
@@ -1143,7 +1321,7 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
defer wg.Done()
// 发送请求并获取响应
response, rtt, err := s.resolver.Exchange(r, server)
response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(server))
select {
case responses <- serverResponse{response, rtt, server, err}:
@@ -1284,7 +1462,7 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
}, 1)
go func() {
response, rtt, err := s.resolver.Exchange(r, selectedDnssecServer)
response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(selectedDnssecServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
@@ -1382,7 +1560,7 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
}, 1)
go func() {
resp, r, e := s.resolver.Exchange(r, localServer)
resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(localServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
@@ -1967,7 +2145,7 @@ func (s *Server) updateStats(update func(*Stats)) {
}
// addQueryLog 添加查询日志
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool, dnsServer, dnssecServer string) {
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool, dnsServer, dnssecServer string, answers []DNSAnswer, responseCode int) {
// 获取IP地理位置
location := s.getIpGeolocation(clientIP)
@@ -1987,6 +2165,8 @@ func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime in
EDNS: edns,
DNSServer: dnsServer,
DNSSECServer: dnssecServer,
Answers: answers,
ResponseCode: responseCode,
}
// 添加到日志列表
@@ -2441,12 +2621,8 @@ func (s *Server) fetchIpGeolocationFromAPI(ip string) (map[string]interface{}, e
// loadStatsData 从文件加载统计数据
func (s *Server) loadStatsData() {
if s.config.StatsFile == "" {
return
}
// 检查文件是否存在
data, err := ioutil.ReadFile(s.config.StatsFile)
data, err := ioutil.ReadFile("data/stats.json")
if err != nil {
if !os.IsNotExist(err) {
logger.Error("读取统计数据文件失败", "error", err)
@@ -2515,14 +2691,10 @@ func (s *Server) loadStatsData() {
// loadQueryLogs 从文件加载查询日志
func (s *Server) loadQueryLogs() {
if s.config.StatsFile == "" {
return
}
// 获取绝对路径
statsFilePath, err := filepath.Abs(s.config.StatsFile)
statsFilePath, err := filepath.Abs("data/stats.json")
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", s.config.StatsFile, "error", err)
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
return
}
@@ -2564,14 +2736,10 @@ func (s *Server) loadQueryLogs() {
// saveStatsData 保存统计数据到文件
func (s *Server) saveStatsData() {
if s.config.StatsFile == "" {
return
}
// 获取绝对路径以避免工作目录问题
statsFilePath, err := filepath.Abs(s.config.StatsFile)
statsFilePath, err := filepath.Abs("data/stats.json")
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", s.config.StatsFile, "error", err)
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
return
}
@@ -2754,15 +2922,15 @@ func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
// startAutoSave 启动自动保存功能
func (s *Server) startAutoSave() {
if s.config.StatsFile == "" || s.config.SaveInterval <= 0 {
if s.config.SaveInterval <= 0 {
return
}
// 设置定时器
// 初始化定时器
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
defer s.saveTicker.Stop()
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", s.config.StatsFile)
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", "data/stats.json")
// 定期保存数据
for {