更新
This commit is contained in:
+698
-19
@@ -1,10 +1,12 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -12,6 +14,7 @@ import (
|
||||
|
||||
"dns-server/config"
|
||||
"dns-server/dns"
|
||||
"dns-server/domain"
|
||||
"dns-server/gfw"
|
||||
"dns-server/logger"
|
||||
"dns-server/shield"
|
||||
@@ -118,7 +121,10 @@ func (s *Server) Start() error {
|
||||
}))
|
||||
mux.HandleFunc("/api/shield/hosts", s.loginRequired(s.handleShieldHosts))
|
||||
mux.HandleFunc("/api/shield/blacklists", s.loginRequired(s.handleShieldBlacklists))
|
||||
// 传统查询接口(保持向后兼容)
|
||||
mux.HandleFunc("/api/query", s.loginRequired(s.handleQuery))
|
||||
// RESTful 域名查询接口
|
||||
mux.HandleFunc("/api/domains/", s.loginRequired(s.handleDomainQuery))
|
||||
mux.HandleFunc("/api/status", s.loginRequired(s.handleStatus))
|
||||
mux.HandleFunc("/api/config", s.loginRequired(s.handleConfig))
|
||||
mux.HandleFunc("/api/config/restart", s.loginRequired(s.handleRestart))
|
||||
@@ -136,7 +142,15 @@ func (s *Server) Start() error {
|
||||
mux.HandleFunc("/api/logs/stats", s.loginRequired(s.handleLogsStats))
|
||||
mux.HandleFunc("/api/logs/query", s.loginRequired(s.handleLogsQuery))
|
||||
mux.HandleFunc("/api/logs/count", s.loginRequired(s.handleLogsCount))
|
||||
// WebSocket端点
|
||||
// 域名查询相关接口
|
||||
mux.HandleFunc("/api/domain/info", s.loginRequired(s.handleDomainInfo))
|
||||
// 域名信息列表接口
|
||||
mux.HandleFunc("/api/domain-info", s.loginRequired(s.handleDomainInfoList))
|
||||
// 威胁查询接口
|
||||
mux.HandleFunc("/api/threat", s.loginRequired(s.handleThreatQuery))
|
||||
// 威胁批量查询接口
|
||||
mux.HandleFunc("/api/threat/batch", s.loginRequired(s.handleThreatBatch))
|
||||
// WebSocket 端点
|
||||
mux.HandleFunc("/ws/stats", s.loginRequired(s.handleWebSocketStats))
|
||||
|
||||
// 将/api/下的静态文件服务指向static/api目录,放在最后以避免覆盖API端点
|
||||
@@ -1165,7 +1179,7 @@ func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleQuery 处理DNS查询请求
|
||||
// handleQuery 处理DNS查询请求(传统接口,保持向后兼容)
|
||||
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
@@ -1174,7 +1188,9 @@ func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
domain := r.URL.Query().Get("domain")
|
||||
if domain == "" {
|
||||
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1188,6 +1204,47 @@ func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(blockDetails)
|
||||
}
|
||||
|
||||
// handleDomainQuery 处理RESTful风格的域名查询请求
|
||||
func (s *Server) handleDomainQuery(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 从URL路径中提取域名参数
|
||||
// 路径格式: /api/domains/{domain}
|
||||
path := r.URL.Path
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 4 {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
|
||||
return
|
||||
}
|
||||
|
||||
domain := parts[3]
|
||||
if domain == "" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取域名屏蔽的详细信息
|
||||
blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
|
||||
|
||||
// 构建RESTful风格的响应
|
||||
response := map[string]interface{}{
|
||||
"domain": domain,
|
||||
"status": blockDetails["blocked"],
|
||||
"timestamp": time.Now(),
|
||||
"details": blockDetails,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// handleStatus 处理系统状态请求
|
||||
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
@@ -1227,7 +1284,7 @@ func saveConfigToFile(config *config.Config, filePath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ioutil.WriteFile(filePath, data, 0644)
|
||||
return os.WriteFile(filePath, data, 0644)
|
||||
}
|
||||
|
||||
// handleConfig 处理配置请求
|
||||
@@ -1259,6 +1316,9 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
"CacheSize": s.globalConfig.DNS.CacheSize,
|
||||
"MaxCacheTTL": s.globalConfig.DNS.MaxCacheTTL,
|
||||
"MinCacheTTL": s.globalConfig.DNS.MinCacheTTL,
|
||||
"enableFastReturn": s.globalConfig.DNS.EnableFastReturn,
|
||||
"domainSpecificDNS": s.globalConfig.DNS.DomainSpecificDNS,
|
||||
"noDNSSECDomains": s.globalConfig.DNS.NoDNSSECDomains,
|
||||
},
|
||||
"HTTPServer": map[string]interface{}{
|
||||
"port": s.globalConfig.HTTP.Port,
|
||||
@@ -1270,17 +1330,20 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
// 更新配置
|
||||
var req struct {
|
||||
DNSServer struct {
|
||||
Port int `json:"port"`
|
||||
QueryMode string `json:"queryMode"`
|
||||
UpstreamServers []string `json:"upstreamServers"`
|
||||
DnssecUpstreamServers []string `json:"dnssecUpstreamServers"`
|
||||
Timeout int `json:"timeout"`
|
||||
SaveInterval int `json:"saveInterval"`
|
||||
EnableIPv6 bool `json:"enableIPv6"`
|
||||
CacheMode string `json:"cacheMode"`
|
||||
CacheSize int `json:"cacheSize"`
|
||||
MaxCacheTTL int `json:"maxCacheTTL"`
|
||||
MinCacheTTL int `json:"minCacheTTL"`
|
||||
Port int `json:"port"`
|
||||
QueryMode string `json:"queryMode"`
|
||||
UpstreamServers []string `json:"upstreamServers"`
|
||||
DnssecUpstreamServers []string `json:"dnssecUpstreamServers"`
|
||||
Timeout int `json:"timeout"`
|
||||
SaveInterval int `json:"saveInterval"`
|
||||
EnableIPv6 bool `json:"enableIPv6"`
|
||||
CacheMode string `json:"cacheMode"`
|
||||
CacheSize int `json:"cacheSize"`
|
||||
MaxCacheTTL int `json:"maxCacheTTL"`
|
||||
MinCacheTTL int `json:"minCacheTTL"`
|
||||
EnableFastReturn *bool `json:"enableFastReturn"`
|
||||
DomainSpecificDNS map[string][]string `json:"domainSpecificDNS"`
|
||||
NoDNSSECDomains []string `json:"noDNSSECDomains"`
|
||||
} `json:"dnsserver"`
|
||||
HTTPServer struct {
|
||||
Port int `json:"port"`
|
||||
@@ -1333,6 +1396,18 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
if req.DNSServer.MinCacheTTL > 0 {
|
||||
s.globalConfig.DNS.MinCacheTTL = req.DNSServer.MinCacheTTL
|
||||
}
|
||||
// 更新enableFastReturn
|
||||
if req.DNSServer.EnableFastReturn != nil {
|
||||
s.globalConfig.DNS.EnableFastReturn = *req.DNSServer.EnableFastReturn
|
||||
}
|
||||
// 更新domainSpecificDNS
|
||||
if req.DNSServer.DomainSpecificDNS != nil {
|
||||
s.globalConfig.DNS.DomainSpecificDNS = req.DNSServer.DomainSpecificDNS
|
||||
}
|
||||
// 更新noDNSSECDomains
|
||||
if len(req.DNSServer.NoDNSSECDomains) > 0 {
|
||||
s.globalConfig.DNS.NoDNSSECDomains = req.DNSServer.NoDNSSECDomains
|
||||
}
|
||||
|
||||
// 更新HTTP配置
|
||||
if req.HTTPServer.Port > 0 {
|
||||
@@ -1514,6 +1589,7 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
|
||||
sortDirection := r.URL.Query().Get("direction")
|
||||
resultFilter := r.URL.Query().Get("result")
|
||||
searchTerm := r.URL.Query().Get("search")
|
||||
queryType := r.URL.Query().Get("queryType")
|
||||
|
||||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||||
fmt.Sscanf(limitStr, "%d", &limit)
|
||||
@@ -1524,7 +1600,7 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// 获取日志数据
|
||||
logs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm)
|
||||
logs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(logs)
|
||||
@@ -1537,13 +1613,52 @@ func (s *Server) handleLogsCount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取日志总数
|
||||
count := s.dnsServer.GetQueryLogsCount()
|
||||
// 获取过滤参数
|
||||
resultFilter := r.URL.Query().Get("result")
|
||||
searchTerm := r.URL.Query().Get("search")
|
||||
queryType := r.URL.Query().Get("queryType")
|
||||
|
||||
// 获取带过滤条件的日志总数
|
||||
count := s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]int{"count": count})
|
||||
}
|
||||
|
||||
// handleDomainInfo 处理域名信息查询请求
|
||||
func (s *Server) handleDomainInfo(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var req struct {
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Domain == "" {
|
||||
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 从域名信息数据库中查询
|
||||
domainInfo, err := domain.GetDomainInfo(req.Domain)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to query domain info", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回域名信息
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(domainInfo)
|
||||
}
|
||||
|
||||
// handleRestart 处理重启服务请求
|
||||
func (s *Server) handleRestart(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -1664,6 +1779,319 @@ func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Info("用户注销成功")
|
||||
}
|
||||
|
||||
// handleDomainInfoList 处理域名信息列表请求
|
||||
func (s *Server) handleDomainInfoList(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取查询参数
|
||||
query := r.URL.Query()
|
||||
|
||||
if query.Has("domains") {
|
||||
// 处理域名信息,支持过滤特定域名
|
||||
domainFilter := query.Get("domains")
|
||||
handleDomainsInfo(w, domainFilter)
|
||||
} else if query.Has("trackers") {
|
||||
// 处理跟踪器信息,支持过滤特定域名
|
||||
trackerFilter := query.Get("trackers")
|
||||
handleTrackersInfo(w, trackerFilter)
|
||||
} else if query.Has("threats") {
|
||||
// 处理威胁域名信息,支持过滤特定域名
|
||||
threatFilter := query.Get("threats")
|
||||
handleThreatsInfo(w, threatFilter)
|
||||
} else {
|
||||
// 直接访问 /domain-info 不提供任何内容
|
||||
http.Error(w, "No content provided", http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// isService 判断一个对象是否是服务(而不是分组)
|
||||
func isService(obj map[string]interface{}) bool {
|
||||
// 服务通常包含 name、url、categoryId 字段
|
||||
_, hasName := obj["name"]
|
||||
_, hasUrl := obj["url"]
|
||||
_, hasCategoryId := obj["categoryId"]
|
||||
|
||||
// 如果有 name 和 url,则认为是服务
|
||||
if hasName && hasUrl {
|
||||
return true
|
||||
}
|
||||
|
||||
// 如果有 categoryId,也认为是服务
|
||||
if hasCategoryId {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// processServiceItem 递归处理服务或分组
|
||||
func processServiceItem(
|
||||
serviceName string,
|
||||
service interface{},
|
||||
companyLevelCompany string,
|
||||
domainFilter string,
|
||||
categories map[string]string,
|
||||
result *[]map[string]interface{},
|
||||
) {
|
||||
serviceMap, ok := service.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// 跳过 company 字段
|
||||
if serviceName == "company" {
|
||||
return
|
||||
}
|
||||
|
||||
// 判断是服务还是分组
|
||||
if isService(serviceMap) {
|
||||
// 这是一个服务,进行处理
|
||||
urlValue := serviceMap["url"]
|
||||
match := false
|
||||
|
||||
// 检查是否需要过滤
|
||||
if domainFilter != "" {
|
||||
// 检查服务名称是否包含过滤条件
|
||||
if serviceName == domainFilter {
|
||||
match = true
|
||||
} else {
|
||||
// 检查 URL 是否包含过滤条件
|
||||
switch v := urlValue.(type) {
|
||||
case string:
|
||||
if strings.Contains(v, domainFilter) {
|
||||
match = true
|
||||
}
|
||||
case map[string]interface{}:
|
||||
for _, url := range v {
|
||||
if urlStr, ok := url.(string); ok && strings.Contains(urlStr, domainFilter) {
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !match {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 确定公司名:优先使用服务级别的 company 字段,否则使用公司级别的 company 字段
|
||||
itemCompany := companyLevelCompany
|
||||
if serviceCompany, ok := serviceMap["company"].(string); ok {
|
||||
itemCompany = serviceCompany
|
||||
}
|
||||
|
||||
// 构建响应对象
|
||||
item := map[string]interface{}{
|
||||
"icon": serviceMap["icon"],
|
||||
"name": serviceMap["name"],
|
||||
"company": itemCompany,
|
||||
}
|
||||
|
||||
// 添加类别
|
||||
if categoryId, ok := serviceMap["categoryId"].(float64); ok {
|
||||
categoryIdStr := fmt.Sprintf("%.0f", categoryId)
|
||||
if category, exists := categories[categoryIdStr]; exists {
|
||||
item["category"] = category
|
||||
}
|
||||
}
|
||||
|
||||
*result = append(*result, item)
|
||||
} else {
|
||||
// 这是一个分组,递归处理其下的子项
|
||||
for subName, subService := range serviceMap {
|
||||
processServiceItem(subName, subService, companyLevelCompany, domainFilter, categories, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
func handleDomainsInfo(w http.ResponseWriter, domainFilter string) {
|
||||
// 如果过滤参数为空字符串,返回空数组
|
||||
if domainFilter == "" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode([]map[string]interface{}{})
|
||||
return
|
||||
}
|
||||
|
||||
filePath := "./static/domain-info/domains/domain-info.json"
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read domain info file", http.StatusInternalServerError)
|
||||
logger.Error(fmt.Sprintf("读取域名信息文件失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
var domainInfo struct {
|
||||
Categories map[string]string `json:"categories"`
|
||||
Domains map[string]map[string]interface{} `json:"domains"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &domainInfo); err != nil {
|
||||
http.Error(w, "Failed to parse domain info file", http.StatusInternalServerError)
|
||||
logger.Error(fmt.Sprintf("解析域名信息文件失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为所需格式
|
||||
var result []map[string]interface{}
|
||||
for _, services := range domainInfo.Domains {
|
||||
// 获取公司级别的 company 字段
|
||||
companyLevelCompany := ""
|
||||
if companyData, ok := services["company"].(string); ok {
|
||||
companyLevelCompany = companyData
|
||||
}
|
||||
|
||||
// 遍历所有服务(包括嵌套的分组)
|
||||
for serviceName, service := range services {
|
||||
processServiceItem(serviceName, service, companyLevelCompany, domainFilter, domainInfo.Categories, &result)
|
||||
}
|
||||
}
|
||||
|
||||
// 返回 JSON 响应
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleTrackersInfo 处理跟踪器信息请求,返回名称、类别、url、所属单位/公司
|
||||
func handleTrackersInfo(w http.ResponseWriter, trackerFilter string) {
|
||||
// 如果过滤参数为空字符串,返回空数组
|
||||
if trackerFilter == "" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode([]map[string]interface{}{})
|
||||
return
|
||||
}
|
||||
|
||||
filePath := "./static/domain-info/tracker/trackers.json"
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read trackers file", http.StatusInternalServerError)
|
||||
logger.Error(fmt.Sprintf("读取跟踪器文件失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
var trackersInfo struct {
|
||||
Categories map[string]string `json:"categories"`
|
||||
Trackers map[string]map[string]interface{} `json:"trackers"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &trackersInfo); err != nil {
|
||||
http.Error(w, "Failed to parse trackers file", http.StatusInternalServerError)
|
||||
logger.Error(fmt.Sprintf("解析跟踪器文件失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为所需格式
|
||||
var result []map[string]interface{}
|
||||
for trackerDomain, tracker := range trackersInfo.Trackers {
|
||||
// 检查是否需要过滤
|
||||
if trackerFilter != "" {
|
||||
// 检查跟踪器域名是否包含过滤条件
|
||||
if !strings.Contains(trackerDomain, trackerFilter) {
|
||||
// 检查名称是否包含过滤条件
|
||||
if name, ok := tracker["name"].(string); !ok || !strings.Contains(name, trackerFilter) {
|
||||
// 检查URL是否包含过滤条件
|
||||
if url, ok := tracker["url"].(string); !ok || !strings.Contains(url, trackerFilter) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
item := map[string]interface{}{
|
||||
"name": tracker["name"],
|
||||
"url": tracker["url"],
|
||||
"company": tracker["companyId"],
|
||||
}
|
||||
|
||||
// 添加类别
|
||||
if categoryId, ok := tracker["categoryId"].(float64); ok {
|
||||
categoryIdStr := fmt.Sprintf("%.0f", categoryId)
|
||||
if category, exists := trackersInfo.Categories[categoryIdStr]; exists {
|
||||
item["category"] = category
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, item)
|
||||
}
|
||||
|
||||
// 返回JSON响应
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleThreatsInfo 处理威胁域名信息请求,返回类型、名称、级别、域名
|
||||
func handleThreatsInfo(w http.ResponseWriter, threatFilter string) {
|
||||
// 如果过滤参数为空字符串,返回空数组
|
||||
if threatFilter == "" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode([]map[string]string{})
|
||||
return
|
||||
}
|
||||
|
||||
filePath := "./static/domain-info/threats/threats-database.csv"
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read threats file", http.StatusInternalServerError)
|
||||
logger.Error(fmt.Sprintf("读取威胁域名文件失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 解析CSV
|
||||
reader := csv.NewReader(bytes.NewReader(data))
|
||||
reader.FieldsPerRecord = -1 // 允许不同长度的记录
|
||||
|
||||
// 读取所有记录
|
||||
records, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to parse threats file", http.StatusInternalServerError)
|
||||
logger.Error(fmt.Sprintf("解析威胁域名文件失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为所需格式
|
||||
var result []map[string]string
|
||||
// 跳过标题行
|
||||
for i, record := range records {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(record) >= 4 {
|
||||
// 检查是否需要过滤
|
||||
if threatFilter != "" {
|
||||
// 检查域名是否包含过滤条件
|
||||
if !strings.Contains(record[3], threatFilter) {
|
||||
// 检查名称是否包含过滤条件
|
||||
if !strings.Contains(record[1], threatFilter) {
|
||||
// 检查类型是否包含过滤条件
|
||||
if !strings.Contains(record[0], threatFilter) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
item := map[string]string{
|
||||
"type": record[0],
|
||||
"name": record[1],
|
||||
"level": record[2],
|
||||
"domain": record[3],
|
||||
}
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
|
||||
// 返回JSON响应
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleChangePassword 处理修改密码请求
|
||||
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -1709,3 +2137,254 @@ func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "密码修改成功"})
|
||||
logger.Info("密码修改成功")
|
||||
}
|
||||
|
||||
// handleThreatQuery 处理威胁域名查询请求
|
||||
// @Summary 查询威胁域名信息
|
||||
// @Description 根据传入的域名参数查询威胁数据库,返回威胁类型、名称、风险等级和域名
|
||||
// @Tags threat
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param domain query string true "要查询的域名"
|
||||
// @Success 200 {string} string "威胁信息,格式:类型,名称,风险等级,域名"
|
||||
// @Failure 400 {object} map[string]string "缺少域名参数"
|
||||
// @Router /api/threat [get]
|
||||
func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取域名参数
|
||||
domain := r.URL.Query().Get("domain")
|
||||
if domain == "" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供 domain 参数"})
|
||||
return
|
||||
}
|
||||
|
||||
// 读取威胁数据库 CSV 文件
|
||||
filePath := "./static/domain-info/threats/threats-database.csv"
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "读取威胁数据库失败"})
|
||||
logger.Error(fmt.Sprintf("读取威胁数据库文件失败:%v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 CSV
|
||||
reader := csv.NewReader(bytes.NewReader(data))
|
||||
reader.FieldsPerRecord = -1 // 允许不同长度的记录
|
||||
|
||||
// 读取所有记录
|
||||
records, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "解析威胁数据库失败"})
|
||||
logger.Error(fmt.Sprintf("解析威胁数据库文件失败:%v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 构建威胁域名映射(支持顶级域名匹配)
|
||||
threatMap := make(map[string][]string)
|
||||
for i, record := range records {
|
||||
if i == 0 {
|
||||
continue // 跳过标题行
|
||||
}
|
||||
if len(record) >= 4 {
|
||||
threatType := record[0] // 第一列:类型
|
||||
threatName := record[1] // 第二列:名称
|
||||
riskLevel := record[2] // 第三列:风险等级
|
||||
domain := record[3] // 第四列:域名
|
||||
threatInfo := []string{threatType, threatName, riskLevel}
|
||||
|
||||
// 1. 完整域名匹配(所有类型都添加)
|
||||
threatMap[domain] = threatInfo
|
||||
|
||||
// 2. 只有恶意网站类型才添加子域名匹配规则
|
||||
// 类型判断:钓鱼网站、仿冒网站
|
||||
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
|
||||
if threatType == "钓鱼网站" || threatType == "仿冒网站" {
|
||||
// 对于恶意网站,添加子域名匹配规则
|
||||
// 例如:sub.example.com -> 添加 .sub.example.com 规则
|
||||
// 这样 a.sub.example.com 就会匹配
|
||||
topLevelDomain := "." + domain
|
||||
// 只有当该顶级域名规则不存在时才添加
|
||||
if _, exists := threatMap[topLevelDomain]; !exists {
|
||||
threatMap[topLevelDomain] = threatInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查询单个域名
|
||||
var result string
|
||||
|
||||
// 1. 先检查完整匹配
|
||||
if threat, exists := threatMap[domain]; exists {
|
||||
result = fmt.Sprintf("%s,%s,%s,%s", threat[0], threat[1], threat[2], domain)
|
||||
} else {
|
||||
// 2. 检查子域名匹配(遍历顶级域名规则)
|
||||
for threatDomain, threatInfo := range threatMap {
|
||||
// 只检查以点开头的顶级域名规则
|
||||
if strings.HasPrefix(threatDomain, ".") && strings.HasSuffix(domain, threatDomain) {
|
||||
// 额外验证:确保是完整的域名部分匹配
|
||||
prefix := strings.TrimSuffix(domain, threatDomain)
|
||||
if len(prefix) > 0 && !strings.HasSuffix(prefix, ".") {
|
||||
// 不是完整的子域名部分,跳过
|
||||
continue
|
||||
}
|
||||
|
||||
result = fmt.Sprintf("%s,%s,%s,%s", threatInfo[0], threatInfo[1], threatInfo[2], domain)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if result == "" {
|
||||
// 未找到匹配的威胁信息
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "无"})
|
||||
} else {
|
||||
// 返回威胁信息
|
||||
json.NewEncoder(w).Encode(map[string]string{"data": result})
|
||||
}
|
||||
}
|
||||
|
||||
// handleThreatBatch 批量查询威胁域名
|
||||
// @Summary 批量查询威胁域名
|
||||
// @Description 批量查询多个域名是否是威胁域名
|
||||
// @Tags threat
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param domains body []string true "域名列表"
|
||||
// @Success 200 {object} map[string]interface{} "批量查询结果"
|
||||
// @Router /api/threat/batch [post]
|
||||
func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Domains []string `json:"domains"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "请求格式错误"})
|
||||
return
|
||||
}
|
||||
|
||||
// 读取威胁数据库 CSV 文件
|
||||
filePath := "./static/domain-info/threats/threats-database.csv"
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "读取威胁数据库失败"})
|
||||
logger.Error(fmt.Sprintf("读取威胁数据库文件失败:%v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 CSV
|
||||
reader := csv.NewReader(bytes.NewReader(data))
|
||||
reader.FieldsPerRecord = -1 // 允许不同长度的记录
|
||||
|
||||
// 读取所有记录
|
||||
records, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "解析威胁数据库失败"})
|
||||
logger.Error(fmt.Sprintf("解析威胁数据库文件失败:%v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 构建威胁域名映射(支持顶级域名匹配)
|
||||
threatMap := make(map[string][]string)
|
||||
for i, record := range records {
|
||||
if i == 0 {
|
||||
continue // 跳过标题行
|
||||
}
|
||||
if len(record) >= 4 {
|
||||
threatType := record[0] // 第一列:类型
|
||||
threatName := record[1] // 第二列:名称
|
||||
riskLevel := record[2] // 第三列:风险等级
|
||||
domain := record[3] // 第四列:域名
|
||||
threatInfo := []string{threatType, threatName, riskLevel}
|
||||
|
||||
// 1. 完整域名匹配(所有类型都添加)
|
||||
threatMap[domain] = threatInfo
|
||||
|
||||
// 2. 只有恶意网站类型才添加子域名匹配规则
|
||||
// 类型判断:钓鱼网站、仿冒网站
|
||||
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
|
||||
if threatType == "钓鱼网站" || threatType == "仿冒网站" {
|
||||
// 对于恶意网站,添加子域名匹配规则
|
||||
// 例如:sub.example.com -> 添加 .sub.example.com 规则
|
||||
// 这样 a.sub.example.com 就会匹配
|
||||
topLevelDomain := "." + domain
|
||||
// 只有当该顶级域名规则不存在时才添加
|
||||
if _, exists := threatMap[topLevelDomain]; !exists {
|
||||
threatMap[topLevelDomain] = threatInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 批量查询
|
||||
results := make([]map[string]interface{}, 0, len(req.Domains))
|
||||
for _, domain := range req.Domains {
|
||||
// 1. 先检查完整匹配
|
||||
if threat, exists := threatMap[domain]; exists {
|
||||
results = append(results, map[string]interface{}{
|
||||
"domain": domain,
|
||||
"isThreat": true,
|
||||
"data": fmt.Sprintf("%s,%s,%s,%s", threat[0], threat[1], threat[2], domain),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// 2. 检查子域名匹配(遍历顶级域名规则)
|
||||
matched := false
|
||||
for threatDomain, threatInfo := range threatMap {
|
||||
// 只检查以点开头的顶级域名规则
|
||||
if strings.HasPrefix(threatDomain, ".") && strings.HasSuffix(domain, threatDomain) {
|
||||
// 验证:确保是有效的子域名匹配
|
||||
// 例如:test.example.com 匹配 .example.com ✅
|
||||
// notexample.com 不应该匹配 .example.com ❌
|
||||
// 去掉 threatDomain 的第一个字符(即去掉开头的点)
|
||||
suffixToTrim := threatDomain[1:]
|
||||
prefix := strings.TrimSuffix(domain, suffixToTrim)
|
||||
|
||||
// 验证逻辑:前缀不为空且以.结尾,或者前缀为空(完全匹配)
|
||||
if len(prefix) == 0 || (len(prefix) > 0 && strings.HasSuffix(prefix, ".")) {
|
||||
results = append(results, map[string]interface{}{
|
||||
"domain": domain,
|
||||
"isThreat": true,
|
||||
"data": fmt.Sprintf("%s,%s,%s,%s", threatInfo[0], threatInfo[1], threatInfo[2], domain),
|
||||
})
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
results = append(results, map[string]interface{}{
|
||||
"domain": domain,
|
||||
"isThreat": false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"results": results,
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user