新建DNS服务器

This commit is contained in:
Alex Yang
2025-11-23 18:21:29 +08:00
commit 0072e8a5c2
15 changed files with 5372 additions and 0 deletions

396
dns/server.go Normal file
View File

@@ -0,0 +1,396 @@
package dns
import (
"context"
"fmt"
"net"
"sort"
"sync"
"time"
"dns-server/config"
"dns-server/logger"
"dns-server/shield"
"github.com/miekg/dns"
)
// BlockedDomain 屏蔽域名统计
type BlockedDomain struct {
Domain string
Count int64
LastSeen time.Time
}
// Server DNS服务器
type Server struct {
config *config.DNSConfig
shieldConfig *config.ShieldConfig
shieldManager *shield.ShieldManager
server *dns.Server
resolver *dns.Client
ctx context.Context
cancel context.CancelFunc
statsMutex sync.Mutex
stats *Stats
blockedDomainsMutex sync.RWMutex
blockedDomains map[string]*BlockedDomain
hourlyStatsMutex sync.RWMutex
hourlyStats map[string]int64 // 按小时统计屏蔽数量
}
// Stats DNS服务器统计信息
type Stats struct {
Queries int64
Blocked int64
Allowed int64
Errors int64
LastQuery time.Time
}
// NewServer 创建DNS服务器实例
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
ctx, cancel := context.WithCancel(context.Background())
return &Server{
config: config,
shieldConfig: shieldConfig,
shieldManager: shieldManager,
resolver: &dns.Client{
Net: "udp",
Timeout: time.Duration(config.Timeout) * time.Millisecond,
},
ctx: ctx,
cancel: cancel,
stats: &Stats{
Queries: 0,
Blocked: 0,
Allowed: 0,
Errors: 0,
},
blockedDomains: make(map[string]*BlockedDomain),
hourlyStats: make(map[string]int64),
}
}
// Start 启动DNS服务器
func (s *Server) Start() error {
s.server = &dns.Server{
Addr: fmt.Sprintf(":%d", s.config.Port),
Net: "udp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 启动TCP服务器用于大型响应
tcpServer := &dns.Server{
Addr: fmt.Sprintf(":%d", s.config.Port),
Net: "tcp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 启动UDP服务
go func() {
logger.Info(fmt.Sprintf("DNS UDP服务器启动监听端口: %d", s.config.Port))
if err := s.server.ListenAndServe(); err != nil {
logger.Error("DNS UDP服务器启动失败", "error", err)
s.cancel()
}
}()
// 启动TCP服务
go func() {
logger.Info(fmt.Sprintf("DNS TCP服务器启动监听端口: %d", s.config.Port))
if err := tcpServer.ListenAndServe(); err != nil {
logger.Error("DNS TCP服务器启动失败", "error", err)
s.cancel()
}
}()
// 等待停止信号
<-s.ctx.Done()
return nil
}
// Stop 停止DNS服务器
func (s *Server) Stop() {
if s.server != nil {
s.server.Shutdown()
}
s.cancel()
logger.Info("DNS服务器已停止")
}
// handleDNSRequest 处理DNS请求
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
s.updateStats(func(stats *Stats) {
stats.Queries++
stats.LastQuery = time.Now()
})
// 只处理递归查询
if r.RecursionDesired == false {
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
response.SetRcode(r, dns.RcodeRefused)
w.WriteMsg(response)
return
}
// 获取查询域名
var domain string
if len(r.Question) > 0 {
domain = r.Question[0].Name
// 移除末尾的点
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
}
logger.Debug("接收到DNS查询", "domain", domain, "type", r.Question[0].Qtype, "client", w.RemoteAddr())
// 检查hosts文件是否有匹配
if ip, exists := s.shieldManager.GetHostsIP(domain); exists {
s.handleHostsResponse(w, r, ip)
return
}
// 检查是否被屏蔽
if s.shieldManager.IsBlocked(domain) {
s.handleBlockedResponse(w, r, domain)
return
}
// 转发到上游DNS服务器
s.forwardDNSRequest(w, r, domain)
}
// handleHostsResponse 处理hosts文件匹配的响应
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
if len(r.Question) > 0 {
q := r.Question[0]
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(ip)
response.Answer = append(response.Answer, answer)
}
w.WriteMsg(response)
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
}
// handleBlockedResponse 处理被屏蔽的域名响应
func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr())
// 更新被屏蔽域名统计
s.updateBlockedDomainStats(domain)
// 更新总体统计
s.updateStats(func(stats *Stats) {
stats.Blocked++
})
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
// 获取屏蔽方法配置
blockMethod := "NXDOMAIN" // 默认值
customBlockIP := "" // 默认值
// 从Server结构体的shieldConfig字段获取配置
if s.shieldConfig != nil {
blockMethod = s.shieldConfig.BlockMethod
customBlockIP = s.shieldConfig.CustomBlockIP
}
// 根据屏蔽方法返回不同的响应
switch blockMethod {
case "refused":
// 返回拒绝查询响应
response.SetRcode(r, dns.RcodeRefused)
case "emptyIP":
// 返回空IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP("0.0.0.0") // 空IP
response.Answer = append(response.Answer, answer)
}
case "customIP":
// 返回自定义IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA && customBlockIP != "" {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(customBlockIP)
response.Answer = append(response.Answer, answer)
}
case "NXDOMAIN", "":
fallthrough // 默认使用NXDOMAIN
default:
// 返回NXDOMAIN响应域名不存在
response.SetRcode(r, dns.RcodeNameError)
}
w.WriteMsg(response)
s.updateStats(func(stats *Stats) {
stats.Blocked++
})
}
// forwardDNSRequest 转发DNS请求到上游服务器
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
// 尝试所有上游DNS服务器
for _, upstream := range s.config.UpstreamDNS {
response, rtt, err := s.resolver.Exchange(r, upstream)
if err == nil && response != nil && response.Rcode == dns.RcodeSuccess {
// 设置递归可用标志
response.RecursionAvailable = true
w.WriteMsg(response)
logger.Debug("DNS查询成功", "domain", domain, "rtt", rtt, "server", upstream)
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return
}
}
// 所有上游服务器都失败,返回服务器失败错误
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
response.SetRcode(r, dns.RcodeServerFailure)
w.WriteMsg(response)
logger.Error("DNS查询失败", "domain", domain)
s.updateStats(func(stats *Stats) {
stats.Errors++
})
}
// updateBlockedDomainStats 更新被屏蔽域名统计
func (s *Server) updateBlockedDomainStats(domain string) {
// 更新被屏蔽域名计数
s.blockedDomainsMutex.Lock()
defer s.blockedDomainsMutex.Unlock()
if entry, exists := s.blockedDomains[domain]; exists {
entry.Count++
entry.LastSeen = time.Now()
} else {
s.blockedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now(),
}
}
// 更新小时统计
hourKey := time.Now().Format("2006-01-02-15")
s.hourlyStats[hourKey]++
}
// updateStats 更新统计信息
func (s *Server) updateStats(update func(*Stats)) {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
update(s.stats)
}
// GetStats 获取DNS服务器统计信息
func (s *Server) GetStats() *Stats {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 返回统计信息的副本
return &Stats{
Queries: s.stats.Queries,
Blocked: s.stats.Blocked,
Allowed: s.stats.Allowed,
Errors: s.stats.Errors,
LastQuery: s.stats.LastQuery,
}
}
// GetTopBlockedDomains 获取TOP屏蔽域名列表
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
domains = append(domains, *entry)
}
// 按计数排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetRecentBlockedDomains 获取最近屏蔽的域名列表
func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
domains = append(domains, *entry)
}
// 按时间排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].LastSeen.After(domains[j].LastSeen)
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetHourlyStats 获取24小时屏蔽统计
func (s *Server) GetHourlyStats() map[string]int64 {
s.hourlyStatsMutex.RLock()
defer s.hourlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.hourlyStats {
result[k] = v
}
return result
}