package dns import ( "context" "fmt" "net" "sort" "sync" "time" "dns-server/config" "dns-server/logger" "dns-server/shield" "github.com/miekg/dns" ) // BlockedDomain 屏蔽域名统计 type BlockedDomain struct { Domain string Count int64 LastSeen time.Time } // Server DNS服务器 type Server struct { config *config.DNSConfig shieldConfig *config.ShieldConfig shieldManager *shield.ShieldManager server *dns.Server resolver *dns.Client ctx context.Context cancel context.CancelFunc statsMutex sync.Mutex stats *Stats blockedDomainsMutex sync.RWMutex blockedDomains map[string]*BlockedDomain resolvedDomainsMutex sync.RWMutex resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名 hourlyStatsMutex sync.RWMutex hourlyStats map[string]int64 // 按小时统计屏蔽数量 } // Stats DNS服务器统计信息 type Stats struct { Queries int64 Blocked int64 Allowed int64 Errors int64 LastQuery time.Time } // NewServer 创建DNS服务器实例 func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server { ctx, cancel := context.WithCancel(context.Background()) return &Server{ config: config, shieldConfig: shieldConfig, shieldManager: shieldManager, resolver: &dns.Client{ Net: "udp", Timeout: time.Duration(config.Timeout) * time.Millisecond, }, ctx: ctx, cancel: cancel, stats: &Stats{ Queries: 0, Blocked: 0, Allowed: 0, Errors: 0, }, blockedDomains: make(map[string]*BlockedDomain), resolvedDomains: make(map[string]*BlockedDomain), hourlyStats: make(map[string]int64), } } // Start 启动DNS服务器 func (s *Server) Start() error { s.server = &dns.Server{ Addr: fmt.Sprintf(":%d", s.config.Port), Net: "udp", Handler: dns.HandlerFunc(s.handleDNSRequest), } // 启动TCP服务器(用于大型响应) tcpServer := &dns.Server{ Addr: fmt.Sprintf(":%d", s.config.Port), Net: "tcp", Handler: dns.HandlerFunc(s.handleDNSRequest), } // 启动UDP服务 go func() { logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port)) if err := s.server.ListenAndServe(); err != nil { logger.Error("DNS UDP服务器启动失败", "error", err) s.cancel() } }() // 启动TCP服务 go func() { logger.Info(fmt.Sprintf("DNS TCP服务器启动,监听端口: %d", s.config.Port)) if err := tcpServer.ListenAndServe(); err != nil { logger.Error("DNS TCP服务器启动失败", "error", err) s.cancel() } }() // 等待停止信号 <-s.ctx.Done() return nil } // Stop 停止DNS服务器 func (s *Server) Stop() { if s.server != nil { s.server.Shutdown() } s.cancel() logger.Info("DNS服务器已停止") } // handleDNSRequest 处理DNS请求 func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { s.updateStats(func(stats *Stats) { stats.Queries++ stats.LastQuery = time.Now() }) // 只处理递归查询 if r.RecursionDesired == false { response := new(dns.Msg) response.SetReply(r) response.RecursionAvailable = true response.SetRcode(r, dns.RcodeRefused) w.WriteMsg(response) return } // 获取查询域名 var domain string if len(r.Question) > 0 { domain = r.Question[0].Name // 移除末尾的点 if len(domain) > 0 && domain[len(domain)-1] == '.' { domain = domain[:len(domain)-1] } } logger.Debug("接收到DNS查询", "domain", domain, "type", r.Question[0].Qtype, "client", w.RemoteAddr()) // 检查hosts文件是否有匹配 if ip, exists := s.shieldManager.GetHostsIP(domain); exists { s.handleHostsResponse(w, r, ip) return } // 检查是否被屏蔽 if s.shieldManager.IsBlocked(domain) { s.handleBlockedResponse(w, r, domain) return } // 转发到上游DNS服务器 s.forwardDNSRequest(w, r, domain) } // handleHostsResponse 处理hosts文件匹配的响应 func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) { response := new(dns.Msg) response.SetReply(r) response.RecursionAvailable = true if len(r.Question) > 0 { q := r.Question[0] answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, } answer.A = net.ParseIP(ip) response.Answer = append(response.Answer, answer) } // 记录解析域名统计 domain := "" if len(r.Question) > 0 { domain = r.Question[0].Name if len(domain) > 0 && domain[len(domain)-1] == '.' { domain = domain[:len(domain)-1] } s.updateResolvedDomainStats(domain) } w.WriteMsg(response) s.updateStats(func(stats *Stats) { stats.Allowed++ }) } // handleBlockedResponse 处理被屏蔽的域名响应 func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) { logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr()) // 更新被屏蔽域名统计 s.updateBlockedDomainStats(domain) // 更新总体统计 s.updateStats(func(stats *Stats) { stats.Blocked++ }) response := new(dns.Msg) response.SetReply(r) response.RecursionAvailable = true // 获取屏蔽方法配置 blockMethod := "NXDOMAIN" // 默认值 customBlockIP := "" // 默认值 // 从Server结构体的shieldConfig字段获取配置 if s.shieldConfig != nil { blockMethod = s.shieldConfig.BlockMethod customBlockIP = s.shieldConfig.CustomBlockIP } // 根据屏蔽方法返回不同的响应 switch blockMethod { case "refused": // 返回拒绝查询响应 response.SetRcode(r, dns.RcodeRefused) case "emptyIP": // 返回空IP响应 if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA { answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, } answer.A = net.ParseIP("0.0.0.0") // 空IP response.Answer = append(response.Answer, answer) } case "customIP": // 返回自定义IP响应 if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA && customBlockIP != "" { answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, } answer.A = net.ParseIP(customBlockIP) response.Answer = append(response.Answer, answer) } case "NXDOMAIN", "": fallthrough // 默认使用NXDOMAIN default: // 返回NXDOMAIN响应(域名不存在) response.SetRcode(r, dns.RcodeNameError) } w.WriteMsg(response) s.updateStats(func(stats *Stats) { stats.Blocked++ }) } // forwardDNSRequest 转发DNS请求到上游服务器 func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) { // 尝试所有上游DNS服务器 for _, upstream := range s.config.UpstreamDNS { response, rtt, err := s.resolver.Exchange(r, upstream) if err == nil && response != nil && response.Rcode == dns.RcodeSuccess { // 设置递归可用标志 response.RecursionAvailable = true w.WriteMsg(response) logger.Debug("DNS查询成功", "domain", domain, "rtt", rtt, "server", upstream) // 记录解析域名统计 s.updateResolvedDomainStats(domain) s.updateStats(func(stats *Stats) { stats.Allowed++ }) return } } // 所有上游服务器都失败,返回服务器失败错误 response := new(dns.Msg) response.SetReply(r) response.RecursionAvailable = true response.SetRcode(r, dns.RcodeServerFailure) w.WriteMsg(response) logger.Error("DNS查询失败", "domain", domain) s.updateStats(func(stats *Stats) { stats.Errors++ }) } // updateBlockedDomainStats 更新被屏蔽域名统计 func (s *Server) updateBlockedDomainStats(domain string) { // 更新被屏蔽域名计数 s.blockedDomainsMutex.Lock() defer s.blockedDomainsMutex.Unlock() if entry, exists := s.blockedDomains[domain]; exists { entry.Count++ entry.LastSeen = time.Now() } else { s.blockedDomains[domain] = &BlockedDomain{ Domain: domain, Count: 1, LastSeen: time.Now(), } } // 更新小时统计 hourKey := time.Now().Format("2006-01-02-15") s.hourlyStatsMutex.Lock() s.hourlyStats[hourKey]++ s.hourlyStatsMutex.Unlock() } // updateResolvedDomainStats 更新解析域名统计 func (s *Server) updateResolvedDomainStats(domain string) { s.resolvedDomainsMutex.Lock() defer s.resolvedDomainsMutex.Unlock() if entry, exists := s.resolvedDomains[domain]; exists { entry.Count++ entry.LastSeen = time.Now() } else { s.resolvedDomains[domain] = &BlockedDomain{ Domain: domain, Count: 1, LastSeen: time.Now(), } } } // updateStats 更新统计信息 func (s *Server) updateStats(update func(*Stats)) { s.statsMutex.Lock() defer s.statsMutex.Unlock() update(s.stats) } // GetStats 获取DNS服务器统计信息 func (s *Server) GetStats() *Stats { s.statsMutex.Lock() defer s.statsMutex.Unlock() // 返回统计信息的副本 return &Stats{ Queries: s.stats.Queries, Blocked: s.stats.Blocked, Allowed: s.stats.Allowed, Errors: s.stats.Errors, LastQuery: s.stats.LastQuery, } } // GetTopBlockedDomains 获取TOP屏蔽域名列表 func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain { s.blockedDomainsMutex.RLock() defer s.blockedDomainsMutex.RUnlock() // 转换为切片 domains := make([]BlockedDomain, 0, len(s.blockedDomains)) for _, entry := range s.blockedDomains { domains = append(domains, *entry) } // 按计数排序 sort.Slice(domains, func(i, j int) bool { return domains[i].Count > domains[j].Count }) // 返回限制数量 if len(domains) > limit { return domains[:limit] } return domains } // GetTopResolvedDomains 获取TOP解析域名 func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain { s.resolvedDomainsMutex.RLock() defer s.resolvedDomainsMutex.RUnlock() // 转换为切片 domains := make([]BlockedDomain, 0, len(s.resolvedDomains)) for _, entry := range s.resolvedDomains { domains = append(domains, *entry) } // 按数量排序 sort.Slice(domains, func(i, j int) bool { return domains[i].Count > domains[j].Count }) // 返回限制数量 if len(domains) > limit { return domains[:limit] } return domains } // GetRecentBlockedDomains 获取最近屏蔽的域名列表 func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain { s.blockedDomainsMutex.RLock() defer s.blockedDomainsMutex.RUnlock() // 转换为切片 domains := make([]BlockedDomain, 0, len(s.blockedDomains)) for _, entry := range s.blockedDomains { domains = append(domains, *entry) } // 按时间排序 sort.Slice(domains, func(i, j int) bool { return domains[i].LastSeen.After(domains[j].LastSeen) }) // 返回限制数量 if len(domains) > limit { return domains[:limit] } return domains } // GetHourlyStats 获取24小时屏蔽统计 func (s *Server) GetHourlyStats() map[string]int64 { s.hourlyStatsMutex.RLock() defer s.hourlyStatsMutex.RUnlock() // 返回副本 result := make(map[string]int64) for k, v := range s.hourlyStats { result[k] = v } return result }