新建DNS服务器
This commit is contained in:
396
dns/server.go
Normal file
396
dns/server.go
Normal file
@@ -0,0 +1,396 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dns-server/config"
|
||||
"dns-server/logger"
|
||||
"dns-server/shield"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// BlockedDomain 屏蔽域名统计
|
||||
|
||||
type BlockedDomain struct {
|
||||
Domain string
|
||||
Count int64
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
// Server DNS服务器
|
||||
type Server struct {
|
||||
config *config.DNSConfig
|
||||
shieldConfig *config.ShieldConfig
|
||||
shieldManager *shield.ShieldManager
|
||||
server *dns.Server
|
||||
resolver *dns.Client
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
statsMutex sync.Mutex
|
||||
stats *Stats
|
||||
blockedDomainsMutex sync.RWMutex
|
||||
blockedDomains map[string]*BlockedDomain
|
||||
hourlyStatsMutex sync.RWMutex
|
||||
hourlyStats map[string]int64 // 按小时统计屏蔽数量
|
||||
}
|
||||
|
||||
// Stats DNS服务器统计信息
|
||||
type Stats struct {
|
||||
Queries int64
|
||||
Blocked int64
|
||||
Allowed int64
|
||||
Errors int64
|
||||
LastQuery time.Time
|
||||
}
|
||||
|
||||
// NewServer 创建DNS服务器实例
|
||||
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Server{
|
||||
config: config,
|
||||
shieldConfig: shieldConfig,
|
||||
shieldManager: shieldManager,
|
||||
resolver: &dns.Client{
|
||||
Net: "udp",
|
||||
Timeout: time.Duration(config.Timeout) * time.Millisecond,
|
||||
},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
stats: &Stats{
|
||||
Queries: 0,
|
||||
Blocked: 0,
|
||||
Allowed: 0,
|
||||
Errors: 0,
|
||||
},
|
||||
blockedDomains: make(map[string]*BlockedDomain),
|
||||
hourlyStats: make(map[string]int64),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动DNS服务器
|
||||
func (s *Server) Start() error {
|
||||
s.server = &dns.Server{
|
||||
Addr: fmt.Sprintf(":%d", s.config.Port),
|
||||
Net: "udp",
|
||||
Handler: dns.HandlerFunc(s.handleDNSRequest),
|
||||
}
|
||||
|
||||
// 启动TCP服务器(用于大型响应)
|
||||
tcpServer := &dns.Server{
|
||||
Addr: fmt.Sprintf(":%d", s.config.Port),
|
||||
Net: "tcp",
|
||||
Handler: dns.HandlerFunc(s.handleDNSRequest),
|
||||
}
|
||||
|
||||
// 启动UDP服务
|
||||
go func() {
|
||||
logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port))
|
||||
if err := s.server.ListenAndServe(); err != nil {
|
||||
logger.Error("DNS UDP服务器启动失败", "error", err)
|
||||
s.cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动TCP服务
|
||||
go func() {
|
||||
logger.Info(fmt.Sprintf("DNS TCP服务器启动,监听端口: %d", s.config.Port))
|
||||
if err := tcpServer.ListenAndServe(); err != nil {
|
||||
logger.Error("DNS TCP服务器启动失败", "error", err)
|
||||
s.cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待停止信号
|
||||
<-s.ctx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止DNS服务器
|
||||
func (s *Server) Stop() {
|
||||
if s.server != nil {
|
||||
s.server.Shutdown()
|
||||
}
|
||||
s.cancel()
|
||||
logger.Info("DNS服务器已停止")
|
||||
}
|
||||
|
||||
// handleDNSRequest 处理DNS请求
|
||||
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.Queries++
|
||||
stats.LastQuery = time.Now()
|
||||
})
|
||||
|
||||
// 只处理递归查询
|
||||
if r.RecursionDesired == false {
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(r)
|
||||
response.RecursionAvailable = true
|
||||
response.SetRcode(r, dns.RcodeRefused)
|
||||
w.WriteMsg(response)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取查询域名
|
||||
var domain string
|
||||
if len(r.Question) > 0 {
|
||||
domain = r.Question[0].Name
|
||||
// 移除末尾的点
|
||||
if len(domain) > 0 && domain[len(domain)-1] == '.' {
|
||||
domain = domain[:len(domain)-1]
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("接收到DNS查询", "domain", domain, "type", r.Question[0].Qtype, "client", w.RemoteAddr())
|
||||
|
||||
// 检查hosts文件是否有匹配
|
||||
if ip, exists := s.shieldManager.GetHostsIP(domain); exists {
|
||||
s.handleHostsResponse(w, r, ip)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否被屏蔽
|
||||
if s.shieldManager.IsBlocked(domain) {
|
||||
s.handleBlockedResponse(w, r, domain)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发到上游DNS服务器
|
||||
s.forwardDNSRequest(w, r, domain)
|
||||
}
|
||||
|
||||
// handleHostsResponse 处理hosts文件匹配的响应
|
||||
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(r)
|
||||
response.RecursionAvailable = true
|
||||
|
||||
if len(r.Question) > 0 {
|
||||
q := r.Question[0]
|
||||
answer := new(dns.A)
|
||||
answer.Hdr = dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
}
|
||||
answer.A = net.ParseIP(ip)
|
||||
response.Answer = append(response.Answer, answer)
|
||||
}
|
||||
|
||||
w.WriteMsg(response)
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.Allowed++
|
||||
})
|
||||
}
|
||||
|
||||
// handleBlockedResponse 处理被屏蔽的域名响应
|
||||
func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||||
logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr())
|
||||
|
||||
// 更新被屏蔽域名统计
|
||||
s.updateBlockedDomainStats(domain)
|
||||
|
||||
// 更新总体统计
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.Blocked++
|
||||
})
|
||||
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(r)
|
||||
response.RecursionAvailable = true
|
||||
|
||||
// 获取屏蔽方法配置
|
||||
blockMethod := "NXDOMAIN" // 默认值
|
||||
customBlockIP := "" // 默认值
|
||||
|
||||
// 从Server结构体的shieldConfig字段获取配置
|
||||
if s.shieldConfig != nil {
|
||||
blockMethod = s.shieldConfig.BlockMethod
|
||||
customBlockIP = s.shieldConfig.CustomBlockIP
|
||||
}
|
||||
|
||||
// 根据屏蔽方法返回不同的响应
|
||||
switch blockMethod {
|
||||
case "refused":
|
||||
// 返回拒绝查询响应
|
||||
response.SetRcode(r, dns.RcodeRefused)
|
||||
case "emptyIP":
|
||||
// 返回空IP响应
|
||||
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
|
||||
answer := new(dns.A)
|
||||
answer.Hdr = dns.RR_Header{
|
||||
Name: r.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
}
|
||||
answer.A = net.ParseIP("0.0.0.0") // 空IP
|
||||
response.Answer = append(response.Answer, answer)
|
||||
}
|
||||
case "customIP":
|
||||
// 返回自定义IP响应
|
||||
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA && customBlockIP != "" {
|
||||
answer := new(dns.A)
|
||||
answer.Hdr = dns.RR_Header{
|
||||
Name: r.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
}
|
||||
answer.A = net.ParseIP(customBlockIP)
|
||||
response.Answer = append(response.Answer, answer)
|
||||
}
|
||||
case "NXDOMAIN", "":
|
||||
fallthrough // 默认使用NXDOMAIN
|
||||
default:
|
||||
// 返回NXDOMAIN响应(域名不存在)
|
||||
response.SetRcode(r, dns.RcodeNameError)
|
||||
}
|
||||
|
||||
w.WriteMsg(response)
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.Blocked++
|
||||
})
|
||||
}
|
||||
|
||||
// forwardDNSRequest 转发DNS请求到上游服务器
|
||||
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||||
// 尝试所有上游DNS服务器
|
||||
for _, upstream := range s.config.UpstreamDNS {
|
||||
response, rtt, err := s.resolver.Exchange(r, upstream)
|
||||
if err == nil && response != nil && response.Rcode == dns.RcodeSuccess {
|
||||
// 设置递归可用标志
|
||||
response.RecursionAvailable = true
|
||||
|
||||
w.WriteMsg(response)
|
||||
logger.Debug("DNS查询成功", "domain", domain, "rtt", rtt, "server", upstream)
|
||||
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.Allowed++
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 所有上游服务器都失败,返回服务器失败错误
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(r)
|
||||
response.RecursionAvailable = true
|
||||
response.SetRcode(r, dns.RcodeServerFailure)
|
||||
w.WriteMsg(response)
|
||||
|
||||
logger.Error("DNS查询失败", "domain", domain)
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.Errors++
|
||||
})
|
||||
}
|
||||
|
||||
// updateBlockedDomainStats 更新被屏蔽域名统计
|
||||
func (s *Server) updateBlockedDomainStats(domain string) {
|
||||
// 更新被屏蔽域名计数
|
||||
s.blockedDomainsMutex.Lock()
|
||||
defer s.blockedDomainsMutex.Unlock()
|
||||
|
||||
if entry, exists := s.blockedDomains[domain]; exists {
|
||||
entry.Count++
|
||||
entry.LastSeen = time.Now()
|
||||
} else {
|
||||
s.blockedDomains[domain] = &BlockedDomain{
|
||||
Domain: domain,
|
||||
Count: 1,
|
||||
LastSeen: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// 更新小时统计
|
||||
hourKey := time.Now().Format("2006-01-02-15")
|
||||
s.hourlyStats[hourKey]++
|
||||
}
|
||||
|
||||
// updateStats 更新统计信息
|
||||
func (s *Server) updateStats(update func(*Stats)) {
|
||||
s.statsMutex.Lock()
|
||||
defer s.statsMutex.Unlock()
|
||||
update(s.stats)
|
||||
}
|
||||
|
||||
// GetStats 获取DNS服务器统计信息
|
||||
func (s *Server) GetStats() *Stats {
|
||||
s.statsMutex.Lock()
|
||||
defer s.statsMutex.Unlock()
|
||||
|
||||
// 返回统计信息的副本
|
||||
return &Stats{
|
||||
Queries: s.stats.Queries,
|
||||
Blocked: s.stats.Blocked,
|
||||
Allowed: s.stats.Allowed,
|
||||
Errors: s.stats.Errors,
|
||||
LastQuery: s.stats.LastQuery,
|
||||
}
|
||||
}
|
||||
|
||||
// GetTopBlockedDomains 获取TOP屏蔽域名列表
|
||||
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
|
||||
s.blockedDomainsMutex.RLock()
|
||||
defer s.blockedDomainsMutex.RUnlock()
|
||||
|
||||
// 转换为切片
|
||||
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
|
||||
for _, entry := range s.blockedDomains {
|
||||
domains = append(domains, *entry)
|
||||
}
|
||||
|
||||
// 按计数排序
|
||||
sort.Slice(domains, func(i, j int) bool {
|
||||
return domains[i].Count > domains[j].Count
|
||||
})
|
||||
|
||||
// 返回限制数量
|
||||
if len(domains) > limit {
|
||||
return domains[:limit]
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// GetRecentBlockedDomains 获取最近屏蔽的域名列表
|
||||
func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain {
|
||||
s.blockedDomainsMutex.RLock()
|
||||
defer s.blockedDomainsMutex.RUnlock()
|
||||
|
||||
// 转换为切片
|
||||
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
|
||||
for _, entry := range s.blockedDomains {
|
||||
domains = append(domains, *entry)
|
||||
}
|
||||
|
||||
// 按时间排序
|
||||
sort.Slice(domains, func(i, j int) bool {
|
||||
return domains[i].LastSeen.After(domains[j].LastSeen)
|
||||
})
|
||||
|
||||
// 返回限制数量
|
||||
if len(domains) > limit {
|
||||
return domains[:limit]
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// GetHourlyStats 获取24小时屏蔽统计
|
||||
func (s *Server) GetHourlyStats() map[string]int64 {
|
||||
s.hourlyStatsMutex.RLock()
|
||||
defer s.hourlyStatsMutex.RUnlock()
|
||||
|
||||
// 返回副本
|
||||
result := make(map[string]int64)
|
||||
for k, v := range s.hourlyStats {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user