This commit is contained in:
Alex Yang
2026-03-30 01:04:46 +08:00
parent 050aa421b1
commit f627244b8f
5978 changed files with 1502187 additions and 2947 deletions
+698 -19
View File
@@ -1,10 +1,12 @@
package http
import (
"bytes"
"encoding/csv"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"sort"
"strings"
"sync"
@@ -12,6 +14,7 @@ import (
"dns-server/config"
"dns-server/dns"
"dns-server/domain"
"dns-server/gfw"
"dns-server/logger"
"dns-server/shield"
@@ -118,7 +121,10 @@ func (s *Server) Start() error {
}))
mux.HandleFunc("/api/shield/hosts", s.loginRequired(s.handleShieldHosts))
mux.HandleFunc("/api/shield/blacklists", s.loginRequired(s.handleShieldBlacklists))
// 传统查询接口(保持向后兼容)
mux.HandleFunc("/api/query", s.loginRequired(s.handleQuery))
// RESTful 域名查询接口
mux.HandleFunc("/api/domains/", s.loginRequired(s.handleDomainQuery))
mux.HandleFunc("/api/status", s.loginRequired(s.handleStatus))
mux.HandleFunc("/api/config", s.loginRequired(s.handleConfig))
mux.HandleFunc("/api/config/restart", s.loginRequired(s.handleRestart))
@@ -136,7 +142,15 @@ func (s *Server) Start() error {
mux.HandleFunc("/api/logs/stats", s.loginRequired(s.handleLogsStats))
mux.HandleFunc("/api/logs/query", s.loginRequired(s.handleLogsQuery))
mux.HandleFunc("/api/logs/count", s.loginRequired(s.handleLogsCount))
// WebSocket端点
// 域名查询相关接口
mux.HandleFunc("/api/domain/info", s.loginRequired(s.handleDomainInfo))
// 域名信息列表接口
mux.HandleFunc("/api/domain-info", s.loginRequired(s.handleDomainInfoList))
// 威胁查询接口
mux.HandleFunc("/api/threat", s.loginRequired(s.handleThreatQuery))
// 威胁批量查询接口
mux.HandleFunc("/api/threat/batch", s.loginRequired(s.handleThreatBatch))
// WebSocket 端点
mux.HandleFunc("/ws/stats", s.loginRequired(s.handleWebSocketStats))
// 将/api/下的静态文件服务指向static/api目录,放在最后以避免覆盖API端点
@@ -1165,7 +1179,7 @@ func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
}
}
// handleQuery 处理DNS查询请求
// handleQuery 处理DNS查询请求(传统接口,保持向后兼容)
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@@ -1174,7 +1188,9 @@ func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
domain := r.URL.Query().Get("domain")
if domain == "" {
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
return
}
@@ -1188,6 +1204,47 @@ func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(blockDetails)
}
// handleDomainQuery 处理RESTful风格的域名查询请求
func (s *Server) handleDomainQuery(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 从URL路径中提取域名参数
// 路径格式: /api/domains/{domain}
path := r.URL.Path
parts := strings.Split(path, "/")
if len(parts) < 4 {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
return
}
domain := parts[3]
if domain == "" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
return
}
// 获取域名屏蔽的详细信息
blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
// 构建RESTful风格的响应
response := map[string]interface{}{
"domain": domain,
"status": blockDetails["blocked"],
"timestamp": time.Now(),
"details": blockDetails,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
// handleStatus 处理系统状态请求
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
@@ -1227,7 +1284,7 @@ func saveConfigToFile(config *config.Config, filePath string) error {
if err != nil {
return err
}
return ioutil.WriteFile(filePath, data, 0644)
return os.WriteFile(filePath, data, 0644)
}
// handleConfig 处理配置请求
@@ -1259,6 +1316,9 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
"CacheSize": s.globalConfig.DNS.CacheSize,
"MaxCacheTTL": s.globalConfig.DNS.MaxCacheTTL,
"MinCacheTTL": s.globalConfig.DNS.MinCacheTTL,
"enableFastReturn": s.globalConfig.DNS.EnableFastReturn,
"domainSpecificDNS": s.globalConfig.DNS.DomainSpecificDNS,
"noDNSSECDomains": s.globalConfig.DNS.NoDNSSECDomains,
},
"HTTPServer": map[string]interface{}{
"port": s.globalConfig.HTTP.Port,
@@ -1270,17 +1330,20 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
// 更新配置
var req struct {
DNSServer struct {
Port int `json:"port"`
QueryMode string `json:"queryMode"`
UpstreamServers []string `json:"upstreamServers"`
DnssecUpstreamServers []string `json:"dnssecUpstreamServers"`
Timeout int `json:"timeout"`
SaveInterval int `json:"saveInterval"`
EnableIPv6 bool `json:"enableIPv6"`
CacheMode string `json:"cacheMode"`
CacheSize int `json:"cacheSize"`
MaxCacheTTL int `json:"maxCacheTTL"`
MinCacheTTL int `json:"minCacheTTL"`
Port int `json:"port"`
QueryMode string `json:"queryMode"`
UpstreamServers []string `json:"upstreamServers"`
DnssecUpstreamServers []string `json:"dnssecUpstreamServers"`
Timeout int `json:"timeout"`
SaveInterval int `json:"saveInterval"`
EnableIPv6 bool `json:"enableIPv6"`
CacheMode string `json:"cacheMode"`
CacheSize int `json:"cacheSize"`
MaxCacheTTL int `json:"maxCacheTTL"`
MinCacheTTL int `json:"minCacheTTL"`
EnableFastReturn *bool `json:"enableFastReturn"`
DomainSpecificDNS map[string][]string `json:"domainSpecificDNS"`
NoDNSSECDomains []string `json:"noDNSSECDomains"`
} `json:"dnsserver"`
HTTPServer struct {
Port int `json:"port"`
@@ -1333,6 +1396,18 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
if req.DNSServer.MinCacheTTL > 0 {
s.globalConfig.DNS.MinCacheTTL = req.DNSServer.MinCacheTTL
}
// 更新enableFastReturn
if req.DNSServer.EnableFastReturn != nil {
s.globalConfig.DNS.EnableFastReturn = *req.DNSServer.EnableFastReturn
}
// 更新domainSpecificDNS
if req.DNSServer.DomainSpecificDNS != nil {
s.globalConfig.DNS.DomainSpecificDNS = req.DNSServer.DomainSpecificDNS
}
// 更新noDNSSECDomains
if len(req.DNSServer.NoDNSSECDomains) > 0 {
s.globalConfig.DNS.NoDNSSECDomains = req.DNSServer.NoDNSSECDomains
}
// 更新HTTP配置
if req.HTTPServer.Port > 0 {
@@ -1514,6 +1589,7 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
sortDirection := r.URL.Query().Get("direction")
resultFilter := r.URL.Query().Get("result")
searchTerm := r.URL.Query().Get("search")
queryType := r.URL.Query().Get("queryType")
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
fmt.Sscanf(limitStr, "%d", &limit)
@@ -1524,7 +1600,7 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
}
// 获取日志数据
logs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm)
logs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(logs)
@@ -1537,13 +1613,52 @@ func (s *Server) handleLogsCount(w http.ResponseWriter, r *http.Request) {
return
}
// 获取日志总
count := s.dnsServer.GetQueryLogsCount()
// 获取过滤参
resultFilter := r.URL.Query().Get("result")
searchTerm := r.URL.Query().Get("search")
queryType := r.URL.Query().Get("queryType")
// 获取带过滤条件的日志总数
count := s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]int{"count": count})
}
// handleDomainInfo 处理域名信息查询请求
func (s *Server) handleDomainInfo(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 解析请求体
var req struct {
Domain string `json:"domain"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.Domain == "" {
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
return
}
// 从域名信息数据库中查询
domainInfo, err := domain.GetDomainInfo(req.Domain)
if err != nil {
http.Error(w, "Failed to query domain info", http.StatusInternalServerError)
return
}
// 返回域名信息
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(domainInfo)
}
// handleRestart 处理重启服务请求
func (s *Server) handleRestart(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -1664,6 +1779,319 @@ func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
logger.Info("用户注销成功")
}
// handleDomainInfoList 处理域名信息列表请求
func (s *Server) handleDomainInfoList(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 获取查询参数
query := r.URL.Query()
if query.Has("domains") {
// 处理域名信息,支持过滤特定域名
domainFilter := query.Get("domains")
handleDomainsInfo(w, domainFilter)
} else if query.Has("trackers") {
// 处理跟踪器信息,支持过滤特定域名
trackerFilter := query.Get("trackers")
handleTrackersInfo(w, trackerFilter)
} else if query.Has("threats") {
// 处理威胁域名信息,支持过滤特定域名
threatFilter := query.Get("threats")
handleThreatsInfo(w, threatFilter)
} else {
// 直接访问 /domain-info 不提供任何内容
http.Error(w, "No content provided", http.StatusNoContent)
return
}
}
// isService 判断一个对象是否是服务(而不是分组)
func isService(obj map[string]interface{}) bool {
// 服务通常包含 name、url、categoryId 字段
_, hasName := obj["name"]
_, hasUrl := obj["url"]
_, hasCategoryId := obj["categoryId"]
// 如果有 name 和 url,则认为是服务
if hasName && hasUrl {
return true
}
// 如果有 categoryId,也认为是服务
if hasCategoryId {
return true
}
return false
}
// processServiceItem 递归处理服务或分组
func processServiceItem(
serviceName string,
service interface{},
companyLevelCompany string,
domainFilter string,
categories map[string]string,
result *[]map[string]interface{},
) {
serviceMap, ok := service.(map[string]interface{})
if !ok {
return
}
// 跳过 company 字段
if serviceName == "company" {
return
}
// 判断是服务还是分组
if isService(serviceMap) {
// 这是一个服务,进行处理
urlValue := serviceMap["url"]
match := false
// 检查是否需要过滤
if domainFilter != "" {
// 检查服务名称是否包含过滤条件
if serviceName == domainFilter {
match = true
} else {
// 检查 URL 是否包含过滤条件
switch v := urlValue.(type) {
case string:
if strings.Contains(v, domainFilter) {
match = true
}
case map[string]interface{}:
for _, url := range v {
if urlStr, ok := url.(string); ok && strings.Contains(urlStr, domainFilter) {
match = true
break
}
}
}
}
if !match {
return
}
}
// 确定公司名:优先使用服务级别的 company 字段,否则使用公司级别的 company 字段
itemCompany := companyLevelCompany
if serviceCompany, ok := serviceMap["company"].(string); ok {
itemCompany = serviceCompany
}
// 构建响应对象
item := map[string]interface{}{
"icon": serviceMap["icon"],
"name": serviceMap["name"],
"company": itemCompany,
}
// 添加类别
if categoryId, ok := serviceMap["categoryId"].(float64); ok {
categoryIdStr := fmt.Sprintf("%.0f", categoryId)
if category, exists := categories[categoryIdStr]; exists {
item["category"] = category
}
}
*result = append(*result, item)
} else {
// 这是一个分组,递归处理其下的子项
for subName, subService := range serviceMap {
processServiceItem(subName, subService, companyLevelCompany, domainFilter, categories, result)
}
}
}
func handleDomainsInfo(w http.ResponseWriter, domainFilter string) {
// 如果过滤参数为空字符串,返回空数组
if domainFilter == "" {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]interface{}{})
return
}
filePath := "./static/domain-info/domains/domain-info.json"
data, err := os.ReadFile(filePath)
if err != nil {
http.Error(w, "Failed to read domain info file", http.StatusInternalServerError)
logger.Error(fmt.Sprintf("读取域名信息文件失败: %v", err))
return
}
// 解析JSON
var domainInfo struct {
Categories map[string]string `json:"categories"`
Domains map[string]map[string]interface{} `json:"domains"`
}
if err := json.Unmarshal(data, &domainInfo); err != nil {
http.Error(w, "Failed to parse domain info file", http.StatusInternalServerError)
logger.Error(fmt.Sprintf("解析域名信息文件失败: %v", err))
return
}
// 转换为所需格式
var result []map[string]interface{}
for _, services := range domainInfo.Domains {
// 获取公司级别的 company 字段
companyLevelCompany := ""
if companyData, ok := services["company"].(string); ok {
companyLevelCompany = companyData
}
// 遍历所有服务(包括嵌套的分组)
for serviceName, service := range services {
processServiceItem(serviceName, service, companyLevelCompany, domainFilter, domainInfo.Categories, &result)
}
}
// 返回 JSON 响应
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleTrackersInfo 处理跟踪器信息请求,返回名称、类别、url、所属单位/公司
func handleTrackersInfo(w http.ResponseWriter, trackerFilter string) {
// 如果过滤参数为空字符串,返回空数组
if trackerFilter == "" {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]interface{}{})
return
}
filePath := "./static/domain-info/tracker/trackers.json"
data, err := os.ReadFile(filePath)
if err != nil {
http.Error(w, "Failed to read trackers file", http.StatusInternalServerError)
logger.Error(fmt.Sprintf("读取跟踪器文件失败: %v", err))
return
}
// 解析JSON
var trackersInfo struct {
Categories map[string]string `json:"categories"`
Trackers map[string]map[string]interface{} `json:"trackers"`
}
if err := json.Unmarshal(data, &trackersInfo); err != nil {
http.Error(w, "Failed to parse trackers file", http.StatusInternalServerError)
logger.Error(fmt.Sprintf("解析跟踪器文件失败: %v", err))
return
}
// 转换为所需格式
var result []map[string]interface{}
for trackerDomain, tracker := range trackersInfo.Trackers {
// 检查是否需要过滤
if trackerFilter != "" {
// 检查跟踪器域名是否包含过滤条件
if !strings.Contains(trackerDomain, trackerFilter) {
// 检查名称是否包含过滤条件
if name, ok := tracker["name"].(string); !ok || !strings.Contains(name, trackerFilter) {
// 检查URL是否包含过滤条件
if url, ok := tracker["url"].(string); !ok || !strings.Contains(url, trackerFilter) {
continue
}
}
}
}
item := map[string]interface{}{
"name": tracker["name"],
"url": tracker["url"],
"company": tracker["companyId"],
}
// 添加类别
if categoryId, ok := tracker["categoryId"].(float64); ok {
categoryIdStr := fmt.Sprintf("%.0f", categoryId)
if category, exists := trackersInfo.Categories[categoryIdStr]; exists {
item["category"] = category
}
}
result = append(result, item)
}
// 返回JSON响应
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleThreatsInfo 处理威胁域名信息请求,返回类型、名称、级别、域名
func handleThreatsInfo(w http.ResponseWriter, threatFilter string) {
// 如果过滤参数为空字符串,返回空数组
if threatFilter == "" {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]string{})
return
}
filePath := "./static/domain-info/threats/threats-database.csv"
data, err := os.ReadFile(filePath)
if err != nil {
http.Error(w, "Failed to read threats file", http.StatusInternalServerError)
logger.Error(fmt.Sprintf("读取威胁域名文件失败: %v", err))
return
}
// 解析CSV
reader := csv.NewReader(bytes.NewReader(data))
reader.FieldsPerRecord = -1 // 允许不同长度的记录
// 读取所有记录
records, err := reader.ReadAll()
if err != nil {
http.Error(w, "Failed to parse threats file", http.StatusInternalServerError)
logger.Error(fmt.Sprintf("解析威胁域名文件失败: %v", err))
return
}
// 转换为所需格式
var result []map[string]string
// 跳过标题行
for i, record := range records {
if i == 0 {
continue
}
if len(record) >= 4 {
// 检查是否需要过滤
if threatFilter != "" {
// 检查域名是否包含过滤条件
if !strings.Contains(record[3], threatFilter) {
// 检查名称是否包含过滤条件
if !strings.Contains(record[1], threatFilter) {
// 检查类型是否包含过滤条件
if !strings.Contains(record[0], threatFilter) {
continue
}
}
}
}
item := map[string]string{
"type": record[0],
"name": record[1],
"level": record[2],
"domain": record[3],
}
result = append(result, item)
}
}
// 返回JSON响应
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleChangePassword 处理修改密码请求
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -1709,3 +2137,254 @@ func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "密码修改成功"})
logger.Info("密码修改成功")
}
// handleThreatQuery 处理威胁域名查询请求
// @Summary 查询威胁域名信息
// @Description 根据传入的域名参数查询威胁数据库,返回威胁类型、名称、风险等级和域名
// @Tags threat
// @Accept json
// @Produce json
// @Param domain query string true "要查询的域名"
// @Success 200 {string} string "威胁信息,格式:类型,名称,风险等级,域名"
// @Failure 400 {object} map[string]string "缺少域名参数"
// @Router /api/threat [get]
func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 获取域名参数
domain := r.URL.Query().Get("domain")
if domain == "" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供 domain 参数"})
return
}
// 读取威胁数据库 CSV 文件
filePath := "./static/domain-info/threats/threats-database.csv"
data, err := os.ReadFile(filePath)
if err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": "读取威胁数据库失败"})
logger.Error(fmt.Sprintf("读取威胁数据库文件失败:%v", err))
return
}
// 解析 CSV
reader := csv.NewReader(bytes.NewReader(data))
reader.FieldsPerRecord = -1 // 允许不同长度的记录
// 读取所有记录
records, err := reader.ReadAll()
if err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": "解析威胁数据库失败"})
logger.Error(fmt.Sprintf("解析威胁数据库文件失败:%v", err))
return
}
// 构建威胁域名映射(支持顶级域名匹配)
threatMap := make(map[string][]string)
for i, record := range records {
if i == 0 {
continue // 跳过标题行
}
if len(record) >= 4 {
threatType := record[0] // 第一列:类型
threatName := record[1] // 第二列:名称
riskLevel := record[2] // 第三列:风险等级
domain := record[3] // 第四列:域名
threatInfo := []string{threatType, threatName, riskLevel}
// 1. 完整域名匹配(所有类型都添加)
threatMap[domain] = threatInfo
// 2. 只有恶意网站类型才添加子域名匹配规则
// 类型判断:钓鱼网站、仿冒网站
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
if threatType == "钓鱼网站" || threatType == "仿冒网站" {
// 对于恶意网站,添加子域名匹配规则
// 例如:sub.example.com -> 添加 .sub.example.com 规则
// 这样 a.sub.example.com 就会匹配
topLevelDomain := "." + domain
// 只有当该顶级域名规则不存在时才添加
if _, exists := threatMap[topLevelDomain]; !exists {
threatMap[topLevelDomain] = threatInfo
}
}
}
}
// 查询单个域名
var result string
// 1. 先检查完整匹配
if threat, exists := threatMap[domain]; exists {
result = fmt.Sprintf("%s,%s,%s,%s", threat[0], threat[1], threat[2], domain)
} else {
// 2. 检查子域名匹配(遍历顶级域名规则)
for threatDomain, threatInfo := range threatMap {
// 只检查以点开头的顶级域名规则
if strings.HasPrefix(threatDomain, ".") && strings.HasSuffix(domain, threatDomain) {
// 额外验证:确保是完整的域名部分匹配
prefix := strings.TrimSuffix(domain, threatDomain)
if len(prefix) > 0 && !strings.HasSuffix(prefix, ".") {
// 不是完整的子域名部分,跳过
continue
}
result = fmt.Sprintf("%s,%s,%s,%s", threatInfo[0], threatInfo[1], threatInfo[2], domain)
break
}
}
}
w.Header().Set("Content-Type", "application/json")
if result == "" {
// 未找到匹配的威胁信息
json.NewEncoder(w).Encode(map[string]string{"message": "无"})
} else {
// 返回威胁信息
json.NewEncoder(w).Encode(map[string]string{"data": result})
}
}
// handleThreatBatch 批量查询威胁域名
// @Summary 批量查询威胁域名
// @Description 批量查询多个域名是否是威胁域名
// @Tags threat
// @Accept json
// @Produce json
// @Param domains body []string true "域名列表"
// @Success 200 {object} map[string]interface{} "批量查询结果"
// @Router /api/threat/batch [post]
func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
Domains []string `json:"domains"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{"error": "请求格式错误"})
return
}
// 读取威胁数据库 CSV 文件
filePath := "./static/domain-info/threats/threats-database.csv"
data, err := os.ReadFile(filePath)
if err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": "读取威胁数据库失败"})
logger.Error(fmt.Sprintf("读取威胁数据库文件失败:%v", err))
return
}
// 解析 CSV
reader := csv.NewReader(bytes.NewReader(data))
reader.FieldsPerRecord = -1 // 允许不同长度的记录
// 读取所有记录
records, err := reader.ReadAll()
if err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": "解析威胁数据库失败"})
logger.Error(fmt.Sprintf("解析威胁数据库文件失败:%v", err))
return
}
// 构建威胁域名映射(支持顶级域名匹配)
threatMap := make(map[string][]string)
for i, record := range records {
if i == 0 {
continue // 跳过标题行
}
if len(record) >= 4 {
threatType := record[0] // 第一列:类型
threatName := record[1] // 第二列:名称
riskLevel := record[2] // 第三列:风险等级
domain := record[3] // 第四列:域名
threatInfo := []string{threatType, threatName, riskLevel}
// 1. 完整域名匹配(所有类型都添加)
threatMap[domain] = threatInfo
// 2. 只有恶意网站类型才添加子域名匹配规则
// 类型判断:钓鱼网站、仿冒网站
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
if threatType == "钓鱼网站" || threatType == "仿冒网站" {
// 对于恶意网站,添加子域名匹配规则
// 例如:sub.example.com -> 添加 .sub.example.com 规则
// 这样 a.sub.example.com 就会匹配
topLevelDomain := "." + domain
// 只有当该顶级域名规则不存在时才添加
if _, exists := threatMap[topLevelDomain]; !exists {
threatMap[topLevelDomain] = threatInfo
}
}
}
}
// 批量查询
results := make([]map[string]interface{}, 0, len(req.Domains))
for _, domain := range req.Domains {
// 1. 先检查完整匹配
if threat, exists := threatMap[domain]; exists {
results = append(results, map[string]interface{}{
"domain": domain,
"isThreat": true,
"data": fmt.Sprintf("%s,%s,%s,%s", threat[0], threat[1], threat[2], domain),
})
continue
}
// 2. 检查子域名匹配(遍历顶级域名规则)
matched := false
for threatDomain, threatInfo := range threatMap {
// 只检查以点开头的顶级域名规则
if strings.HasPrefix(threatDomain, ".") && strings.HasSuffix(domain, threatDomain) {
// 验证:确保是有效的子域名匹配
// 例如:test.example.com 匹配 .example.com ✅
// notexample.com 不应该匹配 .example.com ❌
// 去掉 threatDomain 的第一个字符(即去掉开头的点)
suffixToTrim := threatDomain[1:]
prefix := strings.TrimSuffix(domain, suffixToTrim)
// 验证逻辑:前缀不为空且以.结尾,或者前缀为空(完全匹配)
if len(prefix) == 0 || (len(prefix) > 0 && strings.HasSuffix(prefix, ".")) {
results = append(results, map[string]interface{}{
"domain": domain,
"isThreat": true,
"data": fmt.Sprintf("%s,%s,%s,%s", threatInfo[0], threatInfo[1], threatInfo[2], domain),
})
matched = true
break
}
}
}
if !matched {
results = append(results, map[string]interface{}{
"domain": domain,
"isThreat": false,
})
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"results": results,
})
}