Files
dns-server/dns/server.go
2025-11-23 18:37:24 +08:00

455 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dns
import (
"context"
"fmt"
"net"
"sort"
"sync"
"time"
"dns-server/config"
"dns-server/logger"
"dns-server/shield"
"github.com/miekg/dns"
)
// BlockedDomain is a per-domain hit counter. Despite its name, the same type
// is used for both the blocked-domain and the resolved-domain statistics maps
// held by Server.
type BlockedDomain struct {
// Domain is the queried name, stored without the trailing dot.
Domain string
// Count is the number of times the domain has been recorded.
Count int64
// LastSeen is the time of the most recent recorded hit.
LastSeen time.Time
}
// Server is the DNS server. It answers queries from the hosts table, blocks
// names reported by the shield manager, and forwards everything else to the
// configured upstream resolvers, while keeping in-memory statistics.
type Server struct {
// config holds the listening port, upstream servers and upstream timeout.
config *config.DNSConfig
// shieldConfig selects the block method and custom block IP; may be nil.
shieldConfig *config.ShieldConfig
// shieldManager provides the hosts lookup and the blocked-domain check.
shieldManager *shield.ShieldManager
// server is the UDP listener created by Start; nil before Start is called.
server *dns.Server
// resolver is the client used to query upstream DNS servers.
resolver *dns.Client
// ctx is cancelled to make Start return; cancel is its cancel function.
ctx context.Context
cancel context.CancelFunc
// statsMutex guards stats.
statsMutex sync.Mutex
stats *Stats
// blockedDomainsMutex guards blockedDomains (keyed by domain name).
blockedDomainsMutex sync.RWMutex
blockedDomains map[string]*BlockedDomain
// resolvedDomainsMutex guards resolvedDomains (keyed by domain name).
resolvedDomainsMutex sync.RWMutex
resolvedDomains map[string]*BlockedDomain // records successfully resolved domains
// hourlyStatsMutex guards hourlyStats.
hourlyStatsMutex sync.RWMutex
hourlyStats map[string]int64 // blocked-query count per hour, keyed by "2006-01-02-15" formatted time
}
// Stats holds the aggregate counters of the DNS server.
type Stats struct {
// Queries is incremented for every request passed to the DNS handler.
Queries int64
// Blocked is incremented by the blocked-domain response path.
Blocked int64
// Allowed is incremented when a query is answered from the hosts table or
// by an upstream server.
Allowed int64
// Errors is incremented when no upstream server produced a usable answer.
Errors int64
// LastQuery is the time the most recent request was received.
LastQuery time.Time
}
// NewServer builds a Server from the given DNS configuration, shield
// configuration and shield manager.
//
// The upstream client uses UDP with a timeout of config.Timeout milliseconds.
// All counters start at zero and all statistics maps start empty. The returned
// server owns a fresh cancellable context; nothing is listening until Start is
// called.
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
	upstreamClient := &dns.Client{
		Net:     "udp",
		Timeout: time.Duration(config.Timeout) * time.Millisecond,
	}

	srv := &Server{
		config:          config,
		shieldConfig:    shieldConfig,
		shieldManager:   shieldManager,
		resolver:        upstreamClient,
		stats:           &Stats{},
		blockedDomains:  make(map[string]*BlockedDomain),
		resolvedDomains: make(map[string]*BlockedDomain),
		hourlyStats:     make(map[string]int64),
	}
	srv.ctx, srv.cancel = context.WithCancel(context.Background())
	return srv
}
// Start launches the UDP and TCP listeners on the configured port and blocks
// until the server context is cancelled (by Stop, or because a listener
// failed).
//
// It returns nil on a normal stop. If a listener exits with an error (for
// example the port is already in use), that error is returned so the caller
// can tell a failed start from a clean shutdown. Before returning, both
// listeners are shut down so that neither keeps the port open.
func (s *Server) Start() error {
	addr := fmt.Sprintf(":%d", s.config.Port)
	handler := dns.HandlerFunc(s.handleDNSRequest)

	s.server = &dns.Server{Addr: addr, Net: "udp", Handler: handler}
	// The TCP listener serves responses too large for UDP.
	tcpServer := &dns.Server{Addr: addr, Net: "tcp", Handler: handler}

	// Buffered so that both goroutines can report without blocking, even
	// after Start has stopped reading.
	errCh := make(chan error, 2)

	// UDP listener.
	go func() {
		logger.Info(fmt.Sprintf("DNS UDP服务器启动监听端口: %d", s.config.Port))
		if err := s.server.ListenAndServe(); err != nil {
			logger.Error("DNS UDP服务器启动失败", "error", err)
			// Publish the error before cancelling so Start is guaranteed to see it.
			errCh <- fmt.Errorf("dns udp listener on %s: %w", addr, err)
			s.cancel()
		}
	}()

	// TCP listener.
	go func() {
		logger.Info(fmt.Sprintf("DNS TCP服务器启动监听端口: %d", s.config.Port))
		if err := tcpServer.ListenAndServe(); err != nil {
			logger.Error("DNS TCP服务器启动失败", "error", err)
			errCh <- fmt.Errorf("dns tcp listener on %s: %w", addr, err)
			s.cancel()
		}
	}()

	// Wait for the stop signal.
	<-s.ctx.Done()

	// Best-effort cleanup: Shutdown reports an error when the listener never
	// started or was already shut down, which is expected here and ignored.
	_ = tcpServer.Shutdown()
	_ = s.server.Shutdown()

	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}
// Stop shuts down the listener stored in s.server (if Start has created one),
// cancels the server context so that a blocked Start call returns, and logs
// that the server has stopped.
func (s *Server) Stop() {
	if udp := s.server; udp != nil {
		udp.Shutdown()
	}
	s.cancel()
	logger.Info("DNS服务器已停止")
}
// handleDNSRequest is the entry point for every incoming DNS message.
//
// Processing order:
//  1. count the query;
//  2. refuse non-recursive queries with REFUSED;
//  3. reject messages without a question with FORMERR;
//  4. answer from the hosts table if the name has an entry;
//  5. answer with the configured block response if the name is blocked;
//  6. otherwise forward the query to the upstream servers.
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
	s.updateStats(func(stats *Stats) {
		stats.Queries++
		stats.LastQuery = time.Now()
	})

	// Only recursive queries are served.
	if !r.RecursionDesired {
		response := new(dns.Msg)
		response.SetReply(r)
		response.RecursionAvailable = true
		response.SetRcode(r, dns.RcodeRefused)
		w.WriteMsg(response)
		return
	}

	// Defensive guard: everything below indexes r.Question[0], so a message
	// without a question must be rejected here instead of panicking.
	if len(r.Question) == 0 {
		response := new(dns.Msg)
		response.SetRcode(r, dns.RcodeFormatError)
		response.RecursionAvailable = true
		w.WriteMsg(response)
		s.updateStats(func(stats *Stats) {
			stats.Errors++
		})
		return
	}

	// Extract the queried name and drop the trailing root dot.
	question := r.Question[0]
	domain := question.Name
	if len(domain) > 0 && domain[len(domain)-1] == '.' {
		domain = domain[:len(domain)-1]
	}

	logger.Debug("接收到DNS查询", "domain", domain, "type", question.Qtype, "client", w.RemoteAddr())

	// A hosts entry takes precedence over the block list.
	if ip, exists := s.shieldManager.GetHostsIP(domain); exists {
		s.handleHostsResponse(w, r, ip)
		return
	}

	// Blocked names get the configured block response.
	if s.shieldManager.IsBlocked(domain) {
		s.handleBlockedResponse(w, r, domain)
		return
	}

	// Everything else goes to the upstream DNS servers.
	s.forwardDNSRequest(w, r, domain)
}
// handleHostsResponse answers a query whose name has a hosts-table entry.
//
// ip is the address configured for the name. The answer record follows the
// question type and the address family:
//   - A query with an IPv4 address     -> one A record (TTL 300);
//   - AAAA query with an IPv6 address  -> one AAAA record (TTL 300);
//   - anything else (other query types, family mismatch, unparsable ip)
//     -> NOERROR with an empty answer section.
//
// The name is recorded in the resolved-domain statistics and the query is
// counted as allowed.
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
	response := new(dns.Msg)
	response.SetReply(r)
	response.RecursionAvailable = true

	if len(r.Question) > 0 {
		q := r.Question[0]
		hdr := dns.RR_Header{
			Name:  q.Name,
			Class: dns.ClassINET,
			Ttl:   300,
		}

		parsed := net.ParseIP(ip)
		ip4 := parsed.To4() // nil when ip is invalid or an IPv6 address

		switch {
		case q.Qtype == dns.TypeA && ip4 != nil:
			hdr.Rrtype = dns.TypeA
			response.Answer = append(response.Answer, &dns.A{Hdr: hdr, A: ip4})
		case q.Qtype == dns.TypeAAAA && parsed != nil && ip4 == nil:
			hdr.Rrtype = dns.TypeAAAA
			response.Answer = append(response.Answer, &dns.AAAA{Hdr: hdr, AAAA: parsed})
		}

		// Record the name (without the trailing root dot) as resolved.
		domain := q.Name
		if len(domain) > 0 && domain[len(domain)-1] == '.' {
			domain = domain[:len(domain)-1]
		}
		s.updateResolvedDomainStats(domain)
	}

	w.WriteMsg(response)
	s.updateStats(func(stats *Stats) {
		stats.Allowed++
	})
}
// handleBlockedResponse answers a query for a blocked domain.
//
// The per-domain and hourly block statistics are updated and the global
// Blocked counter is incremented exactly once per call. The response depends
// on the block method in shieldConfig:
//   - "refused":  REFUSED;
//   - "emptyIP":  for A queries, an A record pointing to 0.0.0.0;
//   - "customIP": for A queries, an A record pointing to CustomBlockIP,
//     provided it is a valid IPv4 address;
//   - "NXDOMAIN", "" or any other value (also when shieldConfig is nil):
//     NXDOMAIN.
//
// For "emptyIP"/"customIP", queries that cannot be answered with an A record
// receive NOERROR with an empty answer section.
func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
	logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr())

	// Per-domain and hourly statistics.
	s.updateBlockedDomainStats(domain)
	// Global counter: counted once here and nowhere else in this function.
	s.updateStats(func(stats *Stats) {
		stats.Blocked++
	})

	response := new(dns.Msg)
	response.SetReply(r)
	response.RecursionAvailable = true

	// Defaults apply when no shield configuration is present.
	blockMethod := "NXDOMAIN"
	customBlockIP := ""
	if s.shieldConfig != nil {
		blockMethod = s.shieldConfig.BlockMethod
		customBlockIP = s.shieldConfig.CustomBlockIP
	}

	// addA appends an A record for ip, but only when the question asks for an
	// A record and ip is a valid IPv4 address; otherwise it does nothing.
	addA := func(ip net.IP) {
		ip4 := ip.To4()
		if ip4 == nil || len(r.Question) == 0 || r.Question[0].Qtype != dns.TypeA {
			return
		}
		response.Answer = append(response.Answer, &dns.A{
			Hdr: dns.RR_Header{
				Name:   r.Question[0].Name,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
				Ttl:    300,
			},
			A: ip4,
		})
	}

	switch blockMethod {
	case "refused":
		response.SetRcode(r, dns.RcodeRefused)
	case "emptyIP":
		addA(net.IPv4zero)
	case "customIP":
		// ParseIP returns nil for an empty or malformed address; addA skips it.
		addA(net.ParseIP(customBlockIP))
	default:
		// "NXDOMAIN", "" and unknown methods: the domain does not exist.
		response.SetRcode(r, dns.RcodeNameError)
	}

	w.WriteMsg(response)
}
// forwardDNSRequest relays the query to the upstream DNS servers and writes
// the first definitive answer back to the client.
//
// The entries of s.config.UpstreamDNS are tried in iteration order. A response
// is definitive when its rcode is NOERROR or NXDOMAIN: NXDOMAIN is a valid
// answer ("this name does not exist") and is passed through unchanged rather
// than being turned into a failure. Transport errors and any other rcode make
// the loop move on to the next upstream. If no upstream yields a definitive
// answer, SERVFAIL is returned and the Errors counter is incremented.
//
// Only NOERROR answers are recorded in the resolved-domain statistics; both
// kinds of definitive answer count as allowed.
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
	for _, upstream := range s.config.UpstreamDNS {
		response, rtt, err := s.resolver.Exchange(r, upstream)
		if err != nil || response == nil {
			continue
		}
		if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
			continue
		}

		// Advertise recursion support on the relayed answer.
		response.RecursionAvailable = true
		w.WriteMsg(response)
		logger.Debug("DNS查询成功", "domain", domain, "rtt", rtt, "server", upstream, "rcode", dns.RcodeToString[response.Rcode])

		if response.Rcode == dns.RcodeSuccess {
			s.updateResolvedDomainStats(domain)
		}
		s.updateStats(func(stats *Stats) {
			stats.Allowed++
		})
		return
	}

	// Every upstream failed: report a server failure to the client.
	response := new(dns.Msg)
	response.SetReply(r)
	response.RecursionAvailable = true
	response.SetRcode(r, dns.RcodeServerFailure)
	w.WriteMsg(response)
	logger.Error("DNS查询失败", "domain", domain)
	s.updateStats(func(stats *Stats) {
		stats.Errors++
	})
}
// updateBlockedDomainStats records one blocked query for domain.
//
// It bumps the per-domain counter (creating the entry on first sight),
// refreshes its LastSeen timestamp, and increments the counter of the current
// hour in hourlyStats. The hour bucket key is the local time formatted as
// "2006-01-02-15".
func (s *Server) updateBlockedDomainStats(domain string) {
	s.blockedDomainsMutex.Lock()
	defer s.blockedDomainsMutex.Unlock()

	// Per-domain counter.
	record, known := s.blockedDomains[domain]
	if !known {
		record = &BlockedDomain{Domain: domain}
		s.blockedDomains[domain] = record
	}
	record.Count++
	record.LastSeen = time.Now()

	// Hourly bucket.
	bucket := time.Now().Format("2006-01-02-15")
	s.hourlyStatsMutex.Lock()
	s.hourlyStats[bucket]++
	s.hourlyStatsMutex.Unlock()
}
// updateResolvedDomainStats records one resolved query for domain: it
// increments the per-domain counter (creating the entry on first sight) and
// refreshes its LastSeen timestamp.
func (s *Server) updateResolvedDomainStats(domain string) {
	s.resolvedDomainsMutex.Lock()
	defer s.resolvedDomainsMutex.Unlock()

	record, known := s.resolvedDomains[domain]
	if !known {
		record = &BlockedDomain{Domain: domain}
		s.resolvedDomains[domain] = record
	}
	record.Count++
	record.LastSeen = time.Now()
}
// updateStats runs update against the server's Stats while holding statsMutex,
// so callers can modify several counters as one atomic step. The lock is
// released even if update panics.
func (s *Server) updateStats(update func(*Stats)) {
	mu := &s.statsMutex
	mu.Lock()
	defer mu.Unlock()

	update(s.stats)
}
// GetStats returns a snapshot of the server counters. The returned value is a
// copy taken under statsMutex; modifying it does not affect the live
// statistics.
func (s *Server) GetStats() *Stats {
	s.statsMutex.Lock()
	defer s.statsMutex.Unlock()

	// Copy the struct by value so the caller never shares the live counters.
	snapshot := *s.stats
	return &snapshot
}
// GetTopBlockedDomains returns up to limit blocked domains ordered by
// descending hit count. Domains with equal counts are ordered by name so the
// result is stable between calls. A negative limit is treated as zero and
// yields an empty slice. The returned entries are copies.
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
	// Snapshot under the read lock; sorting happens after it is released so
	// query handling is not held up by the sort.
	s.blockedDomainsMutex.RLock()
	domains := make([]BlockedDomain, 0, len(s.blockedDomains))
	for _, entry := range s.blockedDomains {
		domains = append(domains, *entry)
	}
	s.blockedDomainsMutex.RUnlock()

	sort.Slice(domains, func(i, j int) bool {
		if domains[i].Count != domains[j].Count {
			return domains[i].Count > domains[j].Count
		}
		// Map iteration order is random; break ties deterministically.
		return domains[i].Domain < domains[j].Domain
	})

	// Slicing with a negative bound would panic.
	if limit < 0 {
		limit = 0
	}
	if len(domains) > limit {
		domains = domains[:limit]
	}
	return domains
}
// GetTopResolvedDomains returns up to limit resolved domains ordered by
// descending hit count. Domains with equal counts are ordered by name so the
// result is stable between calls. A negative limit is treated as zero and
// yields an empty slice. The returned entries are copies.
func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain {
	// Snapshot under the read lock; sorting happens after it is released so
	// query handling is not held up by the sort.
	s.resolvedDomainsMutex.RLock()
	domains := make([]BlockedDomain, 0, len(s.resolvedDomains))
	for _, entry := range s.resolvedDomains {
		domains = append(domains, *entry)
	}
	s.resolvedDomainsMutex.RUnlock()

	sort.Slice(domains, func(i, j int) bool {
		if domains[i].Count != domains[j].Count {
			return domains[i].Count > domains[j].Count
		}
		// Map iteration order is random; break ties deterministically.
		return domains[i].Domain < domains[j].Domain
	})

	// Slicing with a negative bound would panic.
	if limit < 0 {
		limit = 0
	}
	if len(domains) > limit {
		domains = domains[:limit]
	}
	return domains
}
// GetRecentBlockedDomains returns up to limit blocked domains ordered by most
// recent hit first. Domains with identical timestamps are ordered by name so
// the result is stable between calls. A negative limit is treated as zero and
// yields an empty slice. The returned entries are copies.
func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain {
	// Snapshot under the read lock; sorting happens after it is released so
	// query handling is not held up by the sort.
	s.blockedDomainsMutex.RLock()
	domains := make([]BlockedDomain, 0, len(s.blockedDomains))
	for _, entry := range s.blockedDomains {
		domains = append(domains, *entry)
	}
	s.blockedDomainsMutex.RUnlock()

	sort.Slice(domains, func(i, j int) bool {
		if !domains[i].LastSeen.Equal(domains[j].LastSeen) {
			return domains[i].LastSeen.After(domains[j].LastSeen)
		}
		// Map iteration order is random; break ties deterministically.
		return domains[i].Domain < domains[j].Domain
	})

	// Slicing with a negative bound would panic.
	if limit < 0 {
		limit = 0
	}
	if len(domains) > limit {
		domains = domains[:limit]
	}
	return domains
}
// GetHourlyStats returns a copy of the per-hour blocked-query counters, keyed
// by hour bucket. Every bucket currently stored is included; this method does
// not filter by age. The returned map is independent of the live statistics.
func (s *Server) GetHourlyStats() map[string]int64 {
	s.hourlyStatsMutex.RLock()
	defer s.hourlyStatsMutex.RUnlock()

	snapshot := make(map[string]int64, len(s.hourlyStats))
	for hour, count := range s.hourlyStats {
		snapshot[hour] = count
	}
	return snapshot
}