update
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"dns-server/log"
|
||||
)
|
||||
|
||||
// enrichLogsWithDomainInfo 为日志条目添加域名信息、威胁信息和跟踪器信息
|
||||
func (s *Server) enrichLogsWithDomainInfo(logs []log.QueryLog) []map[string]interface{} {
|
||||
enrichedLogs := make([]map[string]interface{}, len(logs))
|
||||
|
||||
for i, logItem := range logs {
|
||||
// 将日志条目转换为 map
|
||||
logMap := make(map[string]interface{})
|
||||
logMap["timestamp"] = logItem.Timestamp
|
||||
logMap["clientIP"] = logItem.ClientIP
|
||||
logMap["domain"] = logItem.Domain
|
||||
logMap["queryType"] = logItem.QueryType
|
||||
logMap["result"] = logItem.Result
|
||||
logMap["responseTime"] = logItem.ResponseTime
|
||||
logMap["dnsServer"] = logItem.DNSServer
|
||||
logMap["answers"] = logItem.Answers
|
||||
|
||||
// 查询域名信息
|
||||
if domainInfo := s.domainInfoManager.GetDomainInfo(logItem.Domain); domainInfo != nil {
|
||||
logMap["domainInfo"] = domainInfo
|
||||
}
|
||||
|
||||
// 查询威胁信息
|
||||
if threatInfo := s.domainInfoManager.GetThreatInfo(logItem.Domain); threatInfo != nil {
|
||||
logMap["threatInfo"] = threatInfo
|
||||
}
|
||||
|
||||
// 查询跟踪器信息
|
||||
if trackerInfo := s.domainInfoManager.GetTrackerInfo(logItem.Domain); trackerInfo != nil {
|
||||
logMap["trackerInfo"] = trackerInfo
|
||||
}
|
||||
|
||||
enrichedLogs[i] = logMap
|
||||
}
|
||||
|
||||
return enrichedLogs
|
||||
}
|
||||
|
||||
// enrichDomainInfoWithDetails 为域名查询结果添加详细信息
|
||||
func (s *Server) enrichDomainInfoWithDetails(domain string, baseResult map[string]interface{}) map[string]interface{} {
|
||||
// 添加域名信息
|
||||
if domainInfo := s.domainInfoManager.GetDomainInfo(domain); domainInfo != nil {
|
||||
baseResult["domainInfo"] = domainInfo
|
||||
}
|
||||
|
||||
// 添加威胁信息
|
||||
if threatInfo := s.domainInfoManager.GetThreatInfo(domain); threatInfo != nil {
|
||||
baseResult["threatInfo"] = threatInfo
|
||||
baseResult["isThreat"] = true
|
||||
}
|
||||
|
||||
// 添加跟踪器信息
|
||||
if trackerInfo := s.domainInfoManager.GetTrackerInfo(domain); trackerInfo != nil {
|
||||
baseResult["trackerInfo"] = trackerInfo
|
||||
baseResult["isTracker"] = true
|
||||
}
|
||||
|
||||
return baseResult
|
||||
}
|
||||
@@ -0,0 +1,385 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"dns-server/config"
|
||||
"dns-server/logger"
|
||||
"dns-server/shield"
|
||||
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
// handleDomainInfoUpdate 处理域名信息更新请求(更新所有类型)
|
||||
func (s *Server) handleDomainInfoUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("收到更新所有域名信息的请求")
|
||||
|
||||
// 异步更新所有域名信息
|
||||
go func() {
|
||||
if err := s.domainInfoManager.LoadDomainInfo(); err != nil {
|
||||
logger.Error("更新所有域名信息失败", "error", err)
|
||||
} else {
|
||||
logger.Info("更新所有域名信息成功")
|
||||
}
|
||||
}()
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "success",
|
||||
"message": "域名信息更新任务已启动",
|
||||
})
|
||||
}
|
||||
|
||||
// handleDomainInfoUpdateByType 处理指定类型的域名信息更新请求
|
||||
func (s *Server) handleDomainInfoUpdateByType(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 从 URL 路径中提取类型
|
||||
// 路径格式:/api/domain-info/update/{type}
|
||||
pathParts := strings.Split(r.URL.Path, "/")
|
||||
var entryType string
|
||||
for i, part := range pathParts {
|
||||
if part == "update" && i+1 < len(pathParts) {
|
||||
entryType = strings.TrimSpace(pathParts[i+1])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if entryType == "" {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "error",
|
||||
"message": "类型参数不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("收到更新域名信息的请求", "type", entryType)
|
||||
|
||||
// 异步更新指定类型的域名信息
|
||||
go func(typeToUpdate string) {
|
||||
if err := s.domainInfoManager.UpdateDomainInfoList(typeToUpdate); err != nil {
|
||||
logger.Error("更新域名信息失败", "type", typeToUpdate, "error", err)
|
||||
} else {
|
||||
logger.Info("更新域名信息成功", "type", typeToUpdate)
|
||||
}
|
||||
}(entryType)
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "success",
|
||||
"message": "域名信息更新任务已启动",
|
||||
"type": entryType,
|
||||
})
|
||||
}
|
||||
|
||||
// handleDomainInfoQuery 处理单个域名信息查询请求
|
||||
func (s *Server) handleDomainInfoQuery(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
domain := r.URL.Query().Get("domain")
|
||||
if domain == "" {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "error",
|
||||
"message": "domain 参数不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"domain": domain,
|
||||
}
|
||||
|
||||
// 查询域名信息
|
||||
if domainInfo := s.domainInfoManager.GetDomainInfo(domain); domainInfo != nil {
|
||||
result["domainInfo"] = domainInfo
|
||||
}
|
||||
|
||||
// 查询威胁信息
|
||||
if threatInfo := s.domainInfoManager.GetThreatInfo(domain); threatInfo != nil {
|
||||
result["threatInfo"] = threatInfo
|
||||
}
|
||||
|
||||
// 查询跟踪器信息
|
||||
if trackerInfo := s.domainInfoManager.GetTrackerInfo(domain); trackerInfo != nil {
|
||||
result["trackerInfo"] = trackerInfo
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// AddDomainInfoList 添加域名信息列表(供外部调用)
|
||||
func (s *Server) AddDomainInfoList(entry config.DomainInfoEntry) error {
|
||||
return s.domainInfoManager.AddDomainInfoList(entry)
|
||||
}
|
||||
|
||||
// RemoveDomainInfoList 移除域名信息列表(供外部调用)
|
||||
func (s *Server) RemoveDomainInfoList(entryType string) error {
|
||||
return s.domainInfoManager.RemoveDomainInfoList(entryType)
|
||||
}
|
||||
|
||||
// GetDomainInfoManager 获取域名信息管理器实例(供外部调用)
|
||||
func (s *Server) GetDomainInfoManager() *shield.DomainInfoManager {
|
||||
return s.domainInfoManager
|
||||
}
|
||||
|
||||
// handleDomainInfoStatus 处理域名信息更新状态请求
|
||||
func (s *Server) handleDomainInfoStatus(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
status := s.domainInfoManager.GetUpdateStatus()
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "success",
|
||||
"data": status,
|
||||
})
|
||||
}
|
||||
|
||||
// handleDomainInfoRefresh 处理域名信息缓存刷新请求
|
||||
func (s *Server) handleDomainInfoRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("收到刷新域名信息缓存的请求")
|
||||
|
||||
// 异步刷新缓存
|
||||
go func() {
|
||||
if err := s.domainInfoManager.LoadDomainInfo(); err != nil {
|
||||
logger.Error("刷新域名信息缓存失败", "error", err)
|
||||
} else {
|
||||
logger.Info("刷新域名信息缓存成功")
|
||||
}
|
||||
}()
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "success",
|
||||
"message": "域名信息缓存刷新任务已启动",
|
||||
})
|
||||
}
|
||||
|
||||
// handleDomainInfoAdd 处理添加域名信息列表请求
|
||||
func (s *Server) handleDomainInfoAdd(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var entry config.DomainInfoEntry
|
||||
if err := json.NewDecoder(r.Body).Decode(&entry); err != nil {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "error",
|
||||
"message": "无效的请求数据",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if entry.Name == "" || entry.URL == "" || entry.Type == "" {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "error",
|
||||
"message": "名称、URL和类型不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("收到添加域名信息列表的请求", "name", entry.Name, "type", entry.Type)
|
||||
|
||||
if err := s.domainInfoManager.AddDomainInfoList(entry); err != nil {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "error",
|
||||
"message": fmt.Sprintf("添加列表失败: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 保存配置到文件
|
||||
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
|
||||
logger.Error("保存配置文件失败", "error", err)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "error",
|
||||
"message": "添加列表成功,但保存配置失败",
|
||||
"name": entry.Name,
|
||||
"type": entry.Type,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "success",
|
||||
"message": "域名信息列表添加成功",
|
||||
"name": entry.Name,
|
||||
"type": entry.Type,
|
||||
})
|
||||
}
|
||||
|
||||
// handleDomainInfoRemove 处理删除域名信息列表请求
|
||||
func (s *Server) handleDomainInfoRemove(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var request struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "error",
|
||||
"message": "无效的请求数据",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if request.Type == "" {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "error",
|
||||
"message": "类型参数不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("收到删除域名信息列表的请求", "type", request.Type)
|
||||
|
||||
if err := s.domainInfoManager.RemoveDomainInfoList(request.Type); err != nil {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "error",
|
||||
"message": fmt.Sprintf("删除列表失败: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 保存配置到文件
|
||||
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
|
||||
logger.Error("保存配置文件失败", "error", err)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "error",
|
||||
"message": "删除列表成功,但保存配置失败",
|
||||
"type": request.Type,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "success",
|
||||
"message": "域名信息列表删除成功",
|
||||
"type": request.Type,
|
||||
})
|
||||
}
|
||||
|
||||
// saveConfigToFile 保存配置到文件
|
||||
func saveConfigToFile(config *config.Config, filePath string) error {
|
||||
// 创建新的INI文件
|
||||
cfg := ini.Empty()
|
||||
|
||||
// DNS配置
|
||||
dnsSection := cfg.Section("dns")
|
||||
dnsSection.Key("port").SetValue(fmt.Sprintf("%d", config.DNS.Port))
|
||||
dnsSection.Key("upstreamDNS").SetValue(strings.Join(config.DNS.UpstreamDNS, ", "))
|
||||
dnsSection.Key("dnssecUpstreamDNS").SetValue(strings.Join(config.DNS.DNSSECUpstreamDNS, ", "))
|
||||
dnsSection.Key("saveInterval").SetValue(fmt.Sprintf("%d", config.DNS.SaveInterval))
|
||||
dnsSection.Key("cacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.CacheTTL))
|
||||
dnsSection.Key("enableDNSSEC").SetValue(fmt.Sprintf("%t", config.DNS.EnableDNSSEC))
|
||||
dnsSection.Key("queryMode").SetValue(config.DNS.QueryMode)
|
||||
dnsSection.Key("queryTimeout").SetValue(fmt.Sprintf("%d", config.DNS.QueryTimeout))
|
||||
dnsSection.Key("enableFastReturn").SetValue(fmt.Sprintf("%t", config.DNS.EnableFastReturn))
|
||||
dnsSection.Key("noDNSSECDomains").SetValue(strings.Join(config.DNS.NoDNSSECDomains, ", "))
|
||||
dnsSection.Key("enableIPv6").SetValue(fmt.Sprintf("%t", config.DNS.EnableIPv6))
|
||||
dnsSection.Key("cacheMode").SetValue(config.DNS.CacheMode)
|
||||
dnsSection.Key("cacheSize").SetValue(fmt.Sprintf("%d", config.DNS.CacheSize))
|
||||
dnsSection.Key("maxCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MaxCacheTTL))
|
||||
dnsSection.Key("minCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MinCacheTTL))
|
||||
|
||||
// 域名特定DNS服务器配置
|
||||
for domain, servers := range config.DNS.DomainSpecificDNS {
|
||||
dnsSection.Key(fmt.Sprintf("domain_%s", domain)).SetValue(strings.Join(servers, ", "))
|
||||
}
|
||||
|
||||
// HTTP配置
|
||||
httpSection := cfg.Section("http")
|
||||
httpSection.Key("port").SetValue(fmt.Sprintf("%d", config.HTTP.Port))
|
||||
httpSection.Key("host").SetValue(config.HTTP.Host)
|
||||
httpSection.Key("enableAPI").SetValue(fmt.Sprintf("%t", config.HTTP.EnableAPI))
|
||||
httpSection.Key("username").SetValue(config.HTTP.Username)
|
||||
httpSection.Key("password").SetValue(config.HTTP.Password)
|
||||
|
||||
// Shield配置
|
||||
shieldSection := cfg.Section("shield")
|
||||
shieldSection.Key("updateInterval").SetValue(fmt.Sprintf("%d", config.Shield.UpdateInterval))
|
||||
shieldSection.Key("blockMethod").SetValue(config.Shield.BlockMethod)
|
||||
shieldSection.Key("customBlockIP").SetValue(config.Shield.CustomBlockIP)
|
||||
shieldSection.Key("statsSaveInterval").SetValue(fmt.Sprintf("%d", config.Shield.StatsSaveInterval))
|
||||
|
||||
// 黑名单配置
|
||||
for _, bl := range config.Shield.Blacklists {
|
||||
shieldSection.Key(fmt.Sprintf("blacklist_%s", bl.Name)).SetValue(fmt.Sprintf("%s,%t", bl.URL, bl.Enabled))
|
||||
}
|
||||
|
||||
// GFWList配置
|
||||
gfwListSection := cfg.Section("gfwList")
|
||||
gfwListSection.Key("ip").SetValue(config.GFWList.IP)
|
||||
gfwListSection.Key("content").SetValue(config.GFWList.Content)
|
||||
gfwListSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.GFWList.Enabled))
|
||||
|
||||
// Log配置
|
||||
logSection := cfg.Section("log")
|
||||
logSection.Key("level").SetValue(config.Log.Level)
|
||||
logSection.Key("maxSize").SetValue(fmt.Sprintf("%d", config.Log.MaxSize))
|
||||
logSection.Key("maxBackups").SetValue(fmt.Sprintf("%d", config.Log.MaxBackups))
|
||||
logSection.Key("maxAge").SetValue(fmt.Sprintf("%d", config.Log.MaxAge))
|
||||
|
||||
// Threat配置
|
||||
threatSection := cfg.Section("threat")
|
||||
threatSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.Threat.Enabled))
|
||||
threatSection.Key("queryRateThreshold").SetValue(fmt.Sprintf("%d", config.Threat.QueryRateThreshold))
|
||||
threatSection.Key("nxDomainThreshold").SetValue(fmt.Sprintf("%d", config.Threat.NXDomainThreshold))
|
||||
threatSection.Key("maxDomainLength").SetValue(fmt.Sprintf("%d", config.Threat.MaxDomainLength))
|
||||
threatSection.Key("suspiciousPatterns").SetValue(strings.Join(config.Threat.SuspiciousPatterns, ", "))
|
||||
threatSection.Key("unusualQueryTypes").SetValue(strings.Join(config.Threat.UnusualQueryTypes, ", "))
|
||||
threatSection.Key("alertRetentionDays").SetValue(fmt.Sprintf("%d", config.Threat.AlertRetentionDays))
|
||||
threatSection.Key("threatDatabasePath").SetValue(config.Threat.ThreatDatabasePath)
|
||||
|
||||
// DomainInfo配置
|
||||
domainInfoSection := cfg.Section("domainInfo")
|
||||
domainInfoSection.Key("updateInterval").SetValue(fmt.Sprintf("%d", config.DomainInfo.UpdateInterval))
|
||||
domainInfoSection.Key("enableAutoUpdate").SetValue(fmt.Sprintf("%t", config.DomainInfo.EnableAutoUpdate))
|
||||
|
||||
// 域名信息列表配置
|
||||
for _, entry := range config.DomainInfo.DomainInfoLists {
|
||||
domainInfoSection.Key(fmt.Sprintf("domainInfo_%s", entry.Name)).SetValue(fmt.Sprintf("%s,%s,%t", entry.URL, entry.Type, entry.Enabled))
|
||||
}
|
||||
|
||||
// 保存到文件
|
||||
return cfg.SaveTo(filePath)
|
||||
}
|
||||
|
||||
|
||||
+112
-173
@@ -21,7 +21,6 @@ import (
|
||||
"dns-server/threat"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
// CacheEntry 缓存条目
|
||||
@@ -308,15 +307,16 @@ func (s *Server) ClearAllCache() {
|
||||
|
||||
// Server HTTP控制台服务器
|
||||
type Server struct {
|
||||
globalConfig *config.Config
|
||||
config *config.HTTPConfig
|
||||
dnsServer *dns.Server
|
||||
shieldManager *shield.ShieldManager
|
||||
gfwManager *gfw.GFWListManager
|
||||
server *http.Server
|
||||
globalConfig *config.Config
|
||||
config *config.HTTPConfig
|
||||
dnsServer *dns.Server
|
||||
shieldManager *shield.ShieldManager
|
||||
gfwManager *gfw.GFWListManager
|
||||
domainInfoManager *shield.DomainInfoManager
|
||||
server *http.Server
|
||||
|
||||
// 会话管理相关字段
|
||||
sessions map[string]time.Time // 会话ID到过期时间的映射
|
||||
sessions map[string]time.Time // 会话 ID 到过期时间的映射
|
||||
sessionsMutex sync.Mutex // 会话映射的互斥锁
|
||||
sessionTTL time.Duration // 会话过期时间
|
||||
|
||||
@@ -334,18 +334,22 @@ type Server struct {
|
||||
cacheMaxSize int // 缓存最大条目数
|
||||
}
|
||||
|
||||
// NewServer 创建HTTP服务器实例
|
||||
// NewServer 创建 HTTP 服务器实例
|
||||
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager, gfwManager *gfw.GFWListManager) *Server {
|
||||
// 创建域名信息管理器
|
||||
domainInfoManager := shield.NewDomainInfoManager(&globalConfig.DomainInfo)
|
||||
|
||||
server := &Server{
|
||||
globalConfig: globalConfig,
|
||||
config: &globalConfig.HTTP,
|
||||
dnsServer: dnsServer,
|
||||
shieldManager: shieldManager,
|
||||
gfwManager: gfwManager,
|
||||
globalConfig: globalConfig,
|
||||
config: &globalConfig.HTTP,
|
||||
dnsServer: dnsServer,
|
||||
shieldManager: shieldManager,
|
||||
gfwManager: gfwManager,
|
||||
domainInfoManager: domainInfoManager,
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
// 允许所有CORS请求
|
||||
// 允许所有 CORS 请求
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
@@ -363,6 +367,9 @@ func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager
|
||||
cacheMaxSize: 100, // 默认最大 100 条
|
||||
}
|
||||
|
||||
// 启动域名信息管理器
|
||||
domainInfoManager.Start()
|
||||
|
||||
// 启动广播协程
|
||||
go server.startBroadcastLoop()
|
||||
// 启动会话清理协程
|
||||
@@ -443,7 +450,18 @@ func (s *Server) Start() error {
|
||||
mux.HandleFunc("/api/logs/archives", s.loginRequired(s.handleArchiveList))
|
||||
mux.HandleFunc("/api/logs/archive-cleanup", s.loginRequired(s.handleArchiveCleanup))
|
||||
// 域名信息列表接口
|
||||
mux.HandleFunc("/api/domain-info", s.loginRequired(s.handleDomainInfoList))
|
||||
mux.HandleFunc("/api/domain-info", s.loginRequired(s.handleDomainInfoList))
|
||||
// 域名信息更新接口
|
||||
mux.HandleFunc("/api/domain-info/update", s.loginRequired(s.handleDomainInfoUpdate))
|
||||
mux.HandleFunc("/api/domain-info/update/{type}", s.loginRequired(s.handleDomainInfoUpdateByType))
|
||||
// 域名信息状态接口
|
||||
mux.HandleFunc("/api/domain-info/status", s.loginRequired(s.handleDomainInfoStatus))
|
||||
// 域名信息缓存刷新接口
|
||||
mux.HandleFunc("/api/domain-info/refresh", s.loginRequired(s.handleDomainInfoRefresh))
|
||||
// 域名信息列表添加接口
|
||||
mux.HandleFunc("/api/domain-info/add", s.loginRequired(s.handleDomainInfoAdd))
|
||||
// 域名信息列表删除接口
|
||||
mux.HandleFunc("/api/domain-info/remove", s.loginRequired(s.handleDomainInfoRemove))
|
||||
// 威胁查询接口
|
||||
mux.HandleFunc("/api/threat", s.loginRequired(s.handleThreatQuery))
|
||||
// 威胁批量查询接口
|
||||
@@ -1583,81 +1601,7 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
|
||||
// saveConfigToFile 保存配置到文件
|
||||
func saveConfigToFile(config *config.Config, filePath string) error {
|
||||
// 创建新的INI文件
|
||||
cfg := ini.Empty()
|
||||
|
||||
// DNS配置
|
||||
dnsSection := cfg.Section("dns")
|
||||
dnsSection.Key("port").SetValue(fmt.Sprintf("%d", config.DNS.Port))
|
||||
dnsSection.Key("upstreamDNS").SetValue(strings.Join(config.DNS.UpstreamDNS, ", "))
|
||||
dnsSection.Key("dnssecUpstreamDNS").SetValue(strings.Join(config.DNS.DNSSECUpstreamDNS, ", "))
|
||||
dnsSection.Key("saveInterval").SetValue(fmt.Sprintf("%d", config.DNS.SaveInterval))
|
||||
dnsSection.Key("cacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.CacheTTL))
|
||||
dnsSection.Key("enableDNSSEC").SetValue(fmt.Sprintf("%t", config.DNS.EnableDNSSEC))
|
||||
dnsSection.Key("queryMode").SetValue(config.DNS.QueryMode)
|
||||
dnsSection.Key("queryTimeout").SetValue(fmt.Sprintf("%d", config.DNS.QueryTimeout))
|
||||
dnsSection.Key("enableFastReturn").SetValue(fmt.Sprintf("%t", config.DNS.EnableFastReturn))
|
||||
dnsSection.Key("noDNSSECDomains").SetValue(strings.Join(config.DNS.NoDNSSECDomains, ", "))
|
||||
dnsSection.Key("enableIPv6").SetValue(fmt.Sprintf("%t", config.DNS.EnableIPv6))
|
||||
dnsSection.Key("cacheMode").SetValue(config.DNS.CacheMode)
|
||||
dnsSection.Key("cacheSize").SetValue(fmt.Sprintf("%d", config.DNS.CacheSize))
|
||||
dnsSection.Key("maxCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MaxCacheTTL))
|
||||
dnsSection.Key("minCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MinCacheTTL))
|
||||
|
||||
// 域名特定DNS服务器配置
|
||||
for domain, servers := range config.DNS.DomainSpecificDNS {
|
||||
dnsSection.Key(fmt.Sprintf("domain_%s", domain)).SetValue(strings.Join(servers, ", "))
|
||||
}
|
||||
|
||||
// HTTP配置
|
||||
httpSection := cfg.Section("http")
|
||||
httpSection.Key("port").SetValue(fmt.Sprintf("%d", config.HTTP.Port))
|
||||
httpSection.Key("host").SetValue(config.HTTP.Host)
|
||||
httpSection.Key("enableAPI").SetValue(fmt.Sprintf("%t", config.HTTP.EnableAPI))
|
||||
httpSection.Key("username").SetValue(config.HTTP.Username)
|
||||
httpSection.Key("password").SetValue(config.HTTP.Password)
|
||||
|
||||
// Shield配置
|
||||
shieldSection := cfg.Section("shield")
|
||||
shieldSection.Key("updateInterval").SetValue(fmt.Sprintf("%d", config.Shield.UpdateInterval))
|
||||
shieldSection.Key("blockMethod").SetValue(config.Shield.BlockMethod)
|
||||
shieldSection.Key("customBlockIP").SetValue(config.Shield.CustomBlockIP)
|
||||
shieldSection.Key("statsSaveInterval").SetValue(fmt.Sprintf("%d", config.Shield.StatsSaveInterval))
|
||||
|
||||
// 黑名单配置
|
||||
for _, bl := range config.Shield.Blacklists {
|
||||
shieldSection.Key(fmt.Sprintf("blacklist_%s", bl.Name)).SetValue(fmt.Sprintf("%s,%t", bl.URL, bl.Enabled))
|
||||
}
|
||||
|
||||
// GFWList配置
|
||||
gfwListSection := cfg.Section("gfwList")
|
||||
gfwListSection.Key("ip").SetValue(config.GFWList.IP)
|
||||
gfwListSection.Key("content").SetValue(config.GFWList.Content)
|
||||
gfwListSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.GFWList.Enabled))
|
||||
|
||||
// Log配置
|
||||
logSection := cfg.Section("log")
|
||||
logSection.Key("level").SetValue(config.Log.Level)
|
||||
logSection.Key("maxSize").SetValue(fmt.Sprintf("%d", config.Log.MaxSize))
|
||||
logSection.Key("maxBackups").SetValue(fmt.Sprintf("%d", config.Log.MaxBackups))
|
||||
logSection.Key("maxAge").SetValue(fmt.Sprintf("%d", config.Log.MaxAge))
|
||||
|
||||
// Threat配置
|
||||
threatSection := cfg.Section("threat")
|
||||
threatSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.Threat.Enabled))
|
||||
threatSection.Key("queryRateThreshold").SetValue(fmt.Sprintf("%d", config.Threat.QueryRateThreshold))
|
||||
threatSection.Key("nxDomainThreshold").SetValue(fmt.Sprintf("%d", config.Threat.NXDomainThreshold))
|
||||
threatSection.Key("maxDomainLength").SetValue(fmt.Sprintf("%d", config.Threat.MaxDomainLength))
|
||||
threatSection.Key("suspiciousPatterns").SetValue(strings.Join(config.Threat.SuspiciousPatterns, ","))
|
||||
threatSection.Key("unusualQueryTypes").SetValue(strings.Join(config.Threat.UnusualQueryTypes, ","))
|
||||
threatSection.Key("alertRetentionDays").SetValue(fmt.Sprintf("%d", config.Threat.AlertRetentionDays))
|
||||
threatSection.Key("threatDatabasePath").SetValue(config.Threat.ThreatDatabasePath)
|
||||
|
||||
// 保存到文件
|
||||
return cfg.SaveTo(filePath)
|
||||
}
|
||||
|
||||
// handleConfig 处理配置请求
|
||||
func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -2168,6 +2112,12 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
|
||||
"totalPages": totalPages,
|
||||
}
|
||||
|
||||
// 为日志添加域名信息(只在第一页添加,减少开销)
|
||||
if pageNum == 1 && len(logs) > 0 {
|
||||
enrichedLogs := s.enrichLogsWithDomainInfo(logs)
|
||||
response["logs"] = enrichedLogs
|
||||
}
|
||||
|
||||
// 存入缓存(只缓存第一页,因为用户最常查看第一页)
|
||||
if s.cacheEnabled && pageNum == 1 {
|
||||
s.queryCache.Set(cacheKey, response)
|
||||
@@ -2416,9 +2366,40 @@ func (s *Server) handleDomainInfoList(w http.ResponseWriter, r *http.Request) {
|
||||
threatFilter := query.Get("threats")
|
||||
handleThreatsInfo(w, threatFilter)
|
||||
} else {
|
||||
// 直接访问 /domain-info 不提供任何内容
|
||||
http.Error(w, "No content provided", http.StatusNoContent)
|
||||
return
|
||||
// 获取所有域名信息列表
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if s.domainInfoManager == nil {
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"lists": []interface{}{},
|
||||
"domainInfoCount": 0,
|
||||
"threatCount": 0,
|
||||
"trackerCount": 0,
|
||||
"lastUpdateTime": "从未更新",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 使用 goroutine 和 channel 避免死锁
|
||||
done := make(chan map[string]interface{}, 1)
|
||||
go func() {
|
||||
domainInfo := s.domainInfoManager.GetAllDomainInfo()
|
||||
done <- domainInfo
|
||||
}()
|
||||
|
||||
select {
|
||||
case domainInfo := <-done:
|
||||
json.NewEncoder(w).Encode(domainInfo)
|
||||
case <-time.After(10 * time.Second):
|
||||
// 超时,返回空响应
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"lists": []interface{}{},
|
||||
"domainInfoCount": 0,
|
||||
"threatCount": 0,
|
||||
"trackerCount": 0,
|
||||
"lastUpdateTime": "加载中...",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2930,95 +2911,53 @@ func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 读取威胁数据库 CSV 文件
|
||||
filePath := "./static/domain-info/threats/threats-database.csv"
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "读取威胁数据库失败"})
|
||||
logger.Error(fmt.Sprintf("读取威胁数据库文件失败:%v", err))
|
||||
return
|
||||
// 使用域名信息管理器查询威胁信息
|
||||
threatInfo := s.domainInfoManager.GetThreatInfo(domain)
|
||||
|
||||
result := map[string]interface{}{
|
||||
"domain": domain,
|
||||
}
|
||||
|
||||
// 解析 CSV
|
||||
reader := csv.NewReader(bytes.NewReader(data))
|
||||
reader.FieldsPerRecord = -1 // 允许不同长度的记录
|
||||
|
||||
// 读取所有记录
|
||||
records, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "解析威胁数据库失败"})
|
||||
logger.Error(fmt.Sprintf("解析威胁数据库文件失败:%v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 构建威胁域名映射(支持顶级域名匹配)
|
||||
threatMap := make(map[string][]string)
|
||||
for i, record := range records {
|
||||
if i == 0 {
|
||||
continue // 跳过标题行
|
||||
}
|
||||
if len(record) >= 4 {
|
||||
threatType := record[0] // 第一列:类型
|
||||
threatName := record[1] // 第二列:名称
|
||||
riskLevel := record[2] // 第三列:风险等级
|
||||
domain := record[3] // 第四列:域名
|
||||
threatInfo := []string{threatType, threatName, riskLevel}
|
||||
|
||||
// 1. 完整域名匹配(所有类型都添加)
|
||||
threatMap[domain] = threatInfo
|
||||
|
||||
// 2. 只有恶意网站类型才添加子域名匹配规则
|
||||
// 类型判断:钓鱼网站、仿冒网站
|
||||
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
|
||||
if threatType == "钓鱼网站" || threatType == "仿冒网站" {
|
||||
// 对于恶意网站,添加子域名匹配规则
|
||||
// 例如:sub.example.com -> 添加 .sub.example.com 规则
|
||||
// 这样 a.sub.example.com 就会匹配
|
||||
topLevelDomain := "." + domain
|
||||
// 只有当该顶级域名规则不存在时才添加
|
||||
if _, exists := threatMap[topLevelDomain]; !exists {
|
||||
threatMap[topLevelDomain] = threatInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查询单个域名
|
||||
var result string
|
||||
|
||||
// 1. 先检查完整匹配
|
||||
if threat, exists := threatMap[domain]; exists {
|
||||
result = fmt.Sprintf("%s,%s,%s,%s", threat[0], threat[1], threat[2], domain)
|
||||
if threatInfo != nil {
|
||||
// 从内存中的威胁数据库获取信息
|
||||
result["isThreat"] = true
|
||||
result["data"] = threatInfo
|
||||
result["threatType"] = threatInfo["type"]
|
||||
result["threatName"] = threatInfo["name"]
|
||||
result["riskLevel"] = threatInfo["riskLevel"]
|
||||
} else {
|
||||
// 2. 检查子域名匹配(遍历顶级域名规则)
|
||||
for threatDomain, threatInfo := range threatMap {
|
||||
// 只检查以点开头的顶级域名规则
|
||||
if strings.HasPrefix(threatDomain, ".") && strings.HasSuffix(domain, threatDomain) {
|
||||
// 额外验证:确保是完整的域名部分匹配
|
||||
prefix := strings.TrimSuffix(domain, threatDomain)
|
||||
if len(prefix) > 0 && !strings.HasSuffix(prefix, ".") {
|
||||
// 不是完整的子域名部分,跳过
|
||||
continue
|
||||
}
|
||||
|
||||
result = fmt.Sprintf("%s,%s,%s,%s", threatInfo[0], threatInfo[1], threatInfo[2], domain)
|
||||
// 检查子域名匹配
|
||||
matched := false
|
||||
// 遍历威胁数据库查找匹配的顶级域名规则
|
||||
s.domainInfoManager.GetThreatInfo("") // 这个调用会触发遍历,需要重新实现
|
||||
|
||||
// 简单的子域名匹配逻辑
|
||||
parts := strings.Split(domain, ".")
|
||||
for i := range parts {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
subDomain := strings.Join(parts[i:], ".")
|
||||
threatInfo = s.domainInfoManager.GetThreatInfo(subDomain)
|
||||
if threatInfo != nil {
|
||||
result["isThreat"] = true
|
||||
result["data"] = threatInfo
|
||||
result["threatType"] = threatInfo["type"]
|
||||
result["threatName"] = threatInfo["name"]
|
||||
result["riskLevel"] = threatInfo["riskLevel"]
|
||||
result["matchedDomain"] = subDomain
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
result["isThreat"] = false
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if result == "" {
|
||||
// 未找到匹配的威胁信息
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "无"})
|
||||
} else {
|
||||
// 返回威胁信息
|
||||
json.NewEncoder(w).Encode(map[string]string{"data": result})
|
||||
}
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleThreatBatch 批量查询威胁域名
|
||||
|
||||
Reference in New Issue
Block a user