Files
dns-server/http/domain_info_handlers.go
Alex Yang f9e2e5a6bc update
2026-04-12 21:40:22 +08:00

386 lines
12 KiB
Go

package http
import (
	"encoding/json"
	"fmt"
	"net/http"
	"sort"
	"strings"

	"dns-server/config"
	"dns-server/logger"
	"dns-server/shield"

	"gopkg.in/ini.v1"
)
// handleDomainInfoUpdate handles POST requests that reload every domain-info
// list (all types).
//
// The reload is performed by domainInfoManager.LoadDomainInfo in a background
// goroutine, so the JSON response only confirms that the task was started;
// whether it succeeded is reported solely through the logger. Non-POST
// requests receive 405 Method Not Allowed.
func (s *Server) handleDomainInfoUpdate(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	logger.Info("收到更新所有域名信息的请求")

	reloadAll := func() {
		if err := s.domainInfoManager.LoadDomainInfo(); err != nil {
			logger.Error("更新所有域名信息失败", "error", err)
			return
		}
		logger.Info("更新所有域名信息成功")
	}
	go reloadAll()

	reply := map[string]string{
		"status":  "success",
		"message": "域名信息更新任务已启动",
	}
	json.NewEncoder(w).Encode(reply)
}
// handleDomainInfoUpdateByType handles POST /api/domain-info/update/{type},
// reloading only the domain-info lists of the given type.
//
// The type is taken from the path segment immediately following the first
// "update" segment, with surrounding whitespace trimmed. If no type can be
// extracted, a JSON error object is returned. Otherwise
// domainInfoManager.UpdateDomainInfoList runs in a background goroutine and
// the response only confirms that the task was started; the outcome is logged.
func (s *Server) handleDomainInfoUpdateByType(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Walk the path segments, stopping one short of the end so that an
	// "update" segment is only considered when something follows it.
	segments := strings.Split(r.URL.Path, "/")
	entryType := ""
	for idx := 0; idx+1 < len(segments); idx++ {
		if segments[idx] == "update" {
			entryType = strings.TrimSpace(segments[idx+1])
			break
		}
	}

	if entryType == "" {
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "error",
			"message": "类型参数不能为空",
		})
		return
	}

	logger.Info("收到更新域名信息的请求", "type", entryType)

	refresh := func(kind string) {
		if err := s.domainInfoManager.UpdateDomainInfoList(kind); err != nil {
			logger.Error("更新域名信息失败", "type", kind, "error", err)
			return
		}
		logger.Info("更新域名信息成功", "type", kind)
	}
	go refresh(entryType)

	reply := map[string]string{
		"status":  "success",
		"message": "域名信息更新任务已启动",
		"type":    entryType,
	}
	json.NewEncoder(w).Encode(reply)
}
// handleDomainInfoQuery handles GET requests of the form ?domain=<name> and
// returns what the domain-info manager knows about that domain.
//
// The JSON response always contains "domain". The keys "domainInfo",
// "threatInfo" and "trackerInfo" are added only when the corresponding
// manager lookup returns a non-nil value. A missing or empty domain parameter
// yields a JSON error object instead.
func (s *Server) handleDomainInfoQuery(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	target := r.URL.Query().Get("domain")
	if len(target) == 0 {
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "error",
			"message": "domain 参数不能为空",
		})
		return
	}

	payload := make(map[string]interface{}, 4)
	payload["domain"] = target

	if info := s.domainInfoManager.GetDomainInfo(target); info != nil {
		payload["domainInfo"] = info
	}
	if threat := s.domainInfoManager.GetThreatInfo(target); threat != nil {
		payload["threatInfo"] = threat
	}
	if tracker := s.domainInfoManager.GetTrackerInfo(target); tracker != nil {
		payload["trackerInfo"] = tracker
	}

	json.NewEncoder(w).Encode(payload)
}
// AddDomainInfoList passes entry to the server's domain-info manager
// (domainInfoManager.AddDomainInfoList) and returns its error unchanged.
// It is intended for callers outside this package. Unlike the HTTP add
// handler, it does not write the configuration back to disk.
func (s *Server) AddDomainInfoList(entry config.DomainInfoEntry) error {
	manager := s.domainInfoManager
	return manager.AddDomainInfoList(entry)
}
// RemoveDomainInfoList asks the server's domain-info manager to drop the list
// identified by entryType (domainInfoManager.RemoveDomainInfoList) and returns
// its error unchanged. It is intended for callers outside this package and
// does not persist the configuration to disk.
func (s *Server) RemoveDomainInfoList(entryType string) error {
	manager := s.domainInfoManager
	return manager.RemoveDomainInfoList(entryType)
}
// GetDomainInfoManager exposes the server's domain-info manager instance to
// callers outside this package. The returned pointer is shared, not a copy.
func (s *Server) GetDomainInfoManager() *shield.DomainInfoManager {
	manager := s.domainInfoManager
	return manager
}
// handleDomainInfoStatus handles GET requests for the domain-info update
// status. It wraps domainInfoManager.GetUpdateStatus() in a JSON envelope of
// the form {"status": "success", "data": <status>}.
func (s *Server) handleDomainInfoStatus(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	envelope := map[string]interface{}{"status": "success"}
	envelope["data"] = s.domainInfoManager.GetUpdateStatus()
	json.NewEncoder(w).Encode(envelope)
}
// handleDomainInfoRefresh handles POST requests that refresh the domain-info
// cache.
//
// It calls domainInfoManager.LoadDomainInfo in a background goroutine and
// responds immediately. The JSON reply only confirms that the refresh was
// started; success or failure is reported through the logger.
func (s *Server) handleDomainInfoRefresh(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	logger.Info("收到刷新域名信息缓存的请求")

	go func() {
		err := s.domainInfoManager.LoadDomainInfo()
		if err == nil {
			logger.Info("刷新域名信息缓存成功")
			return
		}
		logger.Error("刷新域名信息缓存失败", "error", err)
	}()

	reply := map[string]string{
		"status":  "success",
		"message": "域名信息缓存刷新任务已启动",
	}
	json.NewEncoder(w).Encode(reply)
}
// handleDomainInfoAdd handles POST requests that register a new domain-info
// list.
//
// The request body is a JSON-encoded config.DomainInfoEntry. Name, URL and
// Type are trimmed of surrounding whitespace and must all be non-empty. On
// success the entry is handed to domainInfoManager.AddDomainInfoList and the
// global configuration is written back to config.ini.
//
// Every outcome is reported as a JSON object with "status" and "message"
// keys. Validation and processing failures are signalled in that body, not
// through the HTTP status code. Non-POST requests receive 405.
func (s *Server) handleDomainInfoAdd(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// The body comes from an untrusted client. Cap its size so an oversized
	// payload is rejected (as invalid data) instead of being buffered into
	// memory by the decoder.
	const maxBodyBytes = 1 << 20
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)

	var entry config.DomainInfoEntry
	if err := json.NewDecoder(r.Body).Decode(&entry); err != nil {
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "error",
			"message": "无效的请求数据",
		})
		return
	}

	// Trim first so whitespace-only values fail the required-field check
	// rather than being persisted as blank INI keys/values.
	entry.Name = strings.TrimSpace(entry.Name)
	entry.URL = strings.TrimSpace(entry.URL)
	entry.Type = strings.TrimSpace(entry.Type)

	if entry.Name == "" || entry.URL == "" || entry.Type == "" {
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "error",
			"message": "名称、URL和类型不能为空",
		})
		return
	}

	logger.Info("收到添加域名信息列表的请求", "name", entry.Name, "type", entry.Type)

	if err := s.domainInfoManager.AddDomainInfoList(entry); err != nil {
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "error",
			"message": fmt.Sprintf("添加列表失败: %v", err),
		})
		return
	}

	// Persist the updated configuration. At this point the list is already
	// active in memory, so a save failure is reported as a partial success.
	if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
		logger.Error("保存配置文件失败", "error", err)
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "error",
			"message": "添加列表成功,但保存配置失败",
			"name":    entry.Name,
			"type":    entry.Type,
		})
		return
	}

	json.NewEncoder(w).Encode(map[string]string{
		"status":  "success",
		"message": "域名信息列表添加成功",
		"name":    entry.Name,
		"type":    entry.Type,
	})
}
// handleDomainInfoRemove handles POST requests that delete a domain-info list.
//
// The JSON body must look like {"type": "<list type>"}. The type is trimmed of
// surrounding whitespace, the same way handleDomainInfoUpdateByType trims the
// type it reads from the URL, and must be non-empty. On success the list is
// removed via domainInfoManager.RemoveDomainInfoList and the global
// configuration is written back to config.ini.
//
// Every outcome is reported as a JSON object with "status" and "message"
// keys. Failures are signalled in that body, not through the HTTP status
// code. Non-POST requests receive 405.
func (s *Server) handleDomainInfoRemove(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// The body comes from an untrusted client. Cap its size so an oversized
	// payload is rejected (as invalid data) instead of being buffered.
	const maxBodyBytes = 1 << 20
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)

	var request struct {
		Type string `json:"type"`
	}
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "error",
			"message": "无效的请求数据",
		})
		return
	}

	request.Type = strings.TrimSpace(request.Type)
	if request.Type == "" {
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "error",
			"message": "类型参数不能为空",
		})
		return
	}

	logger.Info("收到删除域名信息列表的请求", "type", request.Type)

	if err := s.domainInfoManager.RemoveDomainInfoList(request.Type); err != nil {
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "error",
			"message": fmt.Sprintf("删除列表失败: %v", err),
		})
		return
	}

	// Persist the updated configuration. The removal has already taken
	// effect in memory, so a save failure is reported as a partial success.
	if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
		logger.Error("保存配置文件失败", "error", err)
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "error",
			"message": "删除列表成功,但保存配置失败",
			"type":    request.Type,
		})
		return
	}

	json.NewEncoder(w).Encode(map[string]string{
		"status":  "success",
		"message": "域名信息列表删除成功",
		"type":    request.Type,
	})
}
// saveConfigToFile serializes c into INI format and writes it to filePath,
// replacing the file's previous contents.
//
// The file is rebuilt from scratch from the fields written below. Anything
// else in the existing file (comments, unknown keys or sections) is not
// preserved. List-valued settings are joined with ", ". Each per-domain DNS
// entry, blacklist and domain-info list is stored as its own key:
// domain_<domain>, blacklist_<name> and domainInfo_<name>.
//
// The domain_<domain> keys come from a map, so they are written in sorted
// order. Ranging over the map directly would shuffle them on every save and
// make otherwise identical configurations produce different files.
func saveConfigToFile(c *config.Config, filePath string) error {
	file := ini.Empty()

	// [dns]
	dnsSec := file.Section("dns")
	dnsSec.Key("port").SetValue(fmt.Sprintf("%d", c.DNS.Port))
	dnsSec.Key("upstreamDNS").SetValue(strings.Join(c.DNS.UpstreamDNS, ", "))
	dnsSec.Key("dnssecUpstreamDNS").SetValue(strings.Join(c.DNS.DNSSECUpstreamDNS, ", "))
	dnsSec.Key("saveInterval").SetValue(fmt.Sprintf("%d", c.DNS.SaveInterval))
	dnsSec.Key("cacheTTL").SetValue(fmt.Sprintf("%d", c.DNS.CacheTTL))
	dnsSec.Key("enableDNSSEC").SetValue(fmt.Sprintf("%t", c.DNS.EnableDNSSEC))
	dnsSec.Key("queryMode").SetValue(c.DNS.QueryMode)
	dnsSec.Key("queryTimeout").SetValue(fmt.Sprintf("%d", c.DNS.QueryTimeout))
	dnsSec.Key("enableFastReturn").SetValue(fmt.Sprintf("%t", c.DNS.EnableFastReturn))
	dnsSec.Key("noDNSSECDomains").SetValue(strings.Join(c.DNS.NoDNSSECDomains, ", "))
	dnsSec.Key("enableIPv6").SetValue(fmt.Sprintf("%t", c.DNS.EnableIPv6))
	dnsSec.Key("cacheMode").SetValue(c.DNS.CacheMode)
	dnsSec.Key("cacheSize").SetValue(fmt.Sprintf("%d", c.DNS.CacheSize))
	dnsSec.Key("maxCacheTTL").SetValue(fmt.Sprintf("%d", c.DNS.MaxCacheTTL))
	dnsSec.Key("minCacheTTL").SetValue(fmt.Sprintf("%d", c.DNS.MinCacheTTL))

	// Domain-specific upstream servers, emitted in a stable (sorted) order.
	domains := make([]string, 0, len(c.DNS.DomainSpecificDNS))
	for domain := range c.DNS.DomainSpecificDNS {
		domains = append(domains, domain)
	}
	sort.Strings(domains)
	for _, domain := range domains {
		servers := c.DNS.DomainSpecificDNS[domain]
		dnsSec.Key(fmt.Sprintf("domain_%s", domain)).SetValue(strings.Join(servers, ", "))
	}

	// [http]
	httpSec := file.Section("http")
	httpSec.Key("port").SetValue(fmt.Sprintf("%d", c.HTTP.Port))
	httpSec.Key("host").SetValue(c.HTTP.Host)
	httpSec.Key("enableAPI").SetValue(fmt.Sprintf("%t", c.HTTP.EnableAPI))
	httpSec.Key("username").SetValue(c.HTTP.Username)
	httpSec.Key("password").SetValue(c.HTTP.Password)

	// [shield]
	shieldSec := file.Section("shield")
	shieldSec.Key("updateInterval").SetValue(fmt.Sprintf("%d", c.Shield.UpdateInterval))
	shieldSec.Key("blockMethod").SetValue(c.Shield.BlockMethod)
	shieldSec.Key("customBlockIP").SetValue(c.Shield.CustomBlockIP)
	shieldSec.Key("statsSaveInterval").SetValue(fmt.Sprintf("%d", c.Shield.StatsSaveInterval))
	// Blacklists are stored as "<url>,<enabled>" under blacklist_<name>.
	for _, bl := range c.Shield.Blacklists {
		shieldSec.Key(fmt.Sprintf("blacklist_%s", bl.Name)).SetValue(fmt.Sprintf("%s,%t", bl.URL, bl.Enabled))
	}

	// [gfwList]
	gfwSec := file.Section("gfwList")
	gfwSec.Key("ip").SetValue(c.GFWList.IP)
	gfwSec.Key("content").SetValue(c.GFWList.Content)
	gfwSec.Key("enabled").SetValue(fmt.Sprintf("%t", c.GFWList.Enabled))

	// [log]
	logSec := file.Section("log")
	logSec.Key("level").SetValue(c.Log.Level)
	logSec.Key("maxSize").SetValue(fmt.Sprintf("%d", c.Log.MaxSize))
	logSec.Key("maxBackups").SetValue(fmt.Sprintf("%d", c.Log.MaxBackups))
	logSec.Key("maxAge").SetValue(fmt.Sprintf("%d", c.Log.MaxAge))

	// [threat]
	threatSec := file.Section("threat")
	threatSec.Key("enabled").SetValue(fmt.Sprintf("%t", c.Threat.Enabled))
	threatSec.Key("queryRateThreshold").SetValue(fmt.Sprintf("%d", c.Threat.QueryRateThreshold))
	threatSec.Key("nxDomainThreshold").SetValue(fmt.Sprintf("%d", c.Threat.NXDomainThreshold))
	threatSec.Key("maxDomainLength").SetValue(fmt.Sprintf("%d", c.Threat.MaxDomainLength))
	threatSec.Key("suspiciousPatterns").SetValue(strings.Join(c.Threat.SuspiciousPatterns, ", "))
	threatSec.Key("unusualQueryTypes").SetValue(strings.Join(c.Threat.UnusualQueryTypes, ", "))
	threatSec.Key("alertRetentionDays").SetValue(fmt.Sprintf("%d", c.Threat.AlertRetentionDays))
	threatSec.Key("threatDatabasePath").SetValue(c.Threat.ThreatDatabasePath)

	// [domainInfo]
	infoSec := file.Section("domainInfo")
	infoSec.Key("updateInterval").SetValue(fmt.Sprintf("%d", c.DomainInfo.UpdateInterval))
	infoSec.Key("enableAutoUpdate").SetValue(fmt.Sprintf("%t", c.DomainInfo.EnableAutoUpdate))
	// Domain-info lists are stored as "<url>,<type>,<enabled>" under
	// domainInfo_<name>.
	for _, entry := range c.DomainInfo.DomainInfoLists {
		infoSec.Key(fmt.Sprintf("domainInfo_%s", entry.Name)).SetValue(fmt.Sprintf("%s,%s,%t", entry.URL, entry.Type, entry.Enabled))
	}

	return file.SaveTo(filePath)
}