增加威胁域名审计

This commit is contained in:
Alex Yang
2026-04-03 10:04:07 +08:00
parent 170cdb3537
commit f8e222aaf6
41 changed files with 81016 additions and 4672993 deletions
+241 -77
View File
@@ -14,15 +14,14 @@ import (
"dns-server/config"
"dns-server/dns"
"dns-server/domain"
"dns-server/gfw"
"dns-server/log"
"dns-server/logger"
"dns-server/shield"
"gopkg.in/ini.v1"
"dns-server/threat"
"github.com/gorilla/websocket"
"gopkg.in/ini.v1"
)
// CacheEntry 缓存条目
@@ -328,11 +327,11 @@ type Server struct {
broadcastChan chan []byte
// 查询缓存相关字段
queryCache *QueryCache // 查询结果缓存
statsCache *StatsCache // 统计数据缓存
cacheEnabled bool // 缓存是否启用
cacheTTL time.Duration // 缓存过期时间
cacheMaxSize int // 缓存最大条目数
queryCache *QueryCache // 查询结果缓存
statsCache *StatsCache // 统计数据缓存
cacheEnabled bool // 缓存是否启用
cacheTTL time.Duration // 缓存过期时间
cacheMaxSize int // 缓存最大条目数
}
// NewServer 创建HTTP服务器实例
@@ -357,11 +356,11 @@ func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager
sessions: make(map[string]time.Time),
sessionTTL: 24 * time.Hour, // 会话有效期 24 小时
// 查询缓存初始化
queryCache: NewQueryCache(100, 5*time.Second), // 最多 100 条,5 秒过期
statsCache: NewStatsCache(10, 2*time.Second), // 最多 10 条,2 秒过期
cacheEnabled: true, // 默认启用缓存
cacheTTL: 5 * time.Second, // 默认缓存 5 秒
cacheMaxSize: 100, // 默认最大 100 条
queryCache: NewQueryCache(100, 5*time.Second), // 最多 100 条,5 秒过期
statsCache: NewStatsCache(10, 2*time.Second), // 最多 10 条,2 秒过期
cacheEnabled: true, // 默认启用缓存
cacheTTL: 5 * time.Second, // 默认缓存 5 秒
cacheMaxSize: 100, // 默认最大 100 条
}
// 启动广播协程
@@ -440,8 +439,9 @@ func (s *Server) Start() error {
mux.HandleFunc("/api/logs/stats", s.loginRequired(s.handleLogsStats))
mux.HandleFunc("/api/logs/query", s.loginRequired(s.handleLogsQuery))
mux.HandleFunc("/api/logs/count", s.loginRequired(s.handleLogsCount))
// 域名查询相关接口
mux.HandleFunc("/api/domain/info", s.loginRequired(s.handleDomainInfo))
// 归档管理接口
mux.HandleFunc("/api/logs/archives", s.loginRequired(s.handleArchiveList))
mux.HandleFunc("/api/logs/archive-cleanup", s.loginRequired(s.handleArchiveCleanup))
// 域名信息列表接口
mux.HandleFunc("/api/domain-info", s.loginRequired(s.handleDomainInfoList))
// 威胁查询接口
@@ -1559,8 +1559,9 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
stats := s.dnsServer.GetStats()
// 使用服务器的实际启动时间计算准确的运行时间
serverStartTime := s.dnsServer.GetStartTime()
uptime := time.Since(serverStartTime)
// 由于已移除旧日志系统,使用当前时间作为启动时间
serverStartTime := time.Now()
uptime := time.Duration(0)
// 构建包含所有真实服务器统计数据的响应
status := map[string]interface{}{
@@ -1999,6 +2000,30 @@ func (s *Server) handleLogsStats(w http.ResponseWriter, r *http.Request) {
// 缓存未命中,获取最新统计数据
logStats := s.dnsServer.GetQueryStats()
// 添加归档统计信息
archiveManager := s.dnsServer.GetArchiveManager()
if archiveManager != nil {
archives, err := archiveManager.GetArchiveList()
if err == nil {
var archiveTotalRecords int64 = 0
var archiveTotalCompressedSize int64 = 0
for _, archive := range archives {
archiveTotalRecords += archive.RecordCount
archiveTotalCompressedSize += archive.CompressedSize
}
logStats["archiveCount"] = len(archives)
logStats["archiveTotalRecords"] = archiveTotalRecords
logStats["archiveTotalCompressedSize"] = archiveTotalCompressedSize
logStats["archiveTotalSize"] = archiveTotalCompressedSize // 压缩后的大小
// 如果有主库统计,计算总记录数
if totalRecords, ok := logStats["totalQueries"].(int64); ok {
logStats["grandTotalRecords"] = totalRecords + archiveTotalRecords
}
}
}
// 存入缓存
if s.cacheEnabled {
s.statsCache.Set(cacheKey, logStats)
@@ -2016,8 +2041,8 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
}
// 获取查询参数
limit := 100 // 默认返回 100 条日志
offset := 0
limit := 30 // 默认返回 30 条日志
pageNum := 1 // 默认第 1 页
sortField := r.URL.Query().Get("sort")
sortDirection := r.URL.Query().Get("direction")
resultFilter := r.URL.Query().Get("result")
@@ -2028,24 +2053,128 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
fmt.Sscanf(limitStr, "%d", &limit)
}
// 支持 page 参数(优先)或 offset 参数
if pageStr := r.URL.Query().Get("page"); pageStr != "" {
fmt.Sscanf(pageStr, "%d", &pageNum)
}
offset := (pageNum - 1) * limit
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
fmt.Sscanf(offsetStr, "%d", &offset)
}
// 构建缓存键,包含所有查询参数
// 已禁用缓存,每次都从数据库获取最新数据
// cacheKey := fmt.Sprintf("logs_%d_%d_%s_%s_%s_%s_%s", limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
// 构建缓存键,包含所有查询参数(排除 offset,因为分页数据可以复用)
cacheKey := fmt.Sprintf("logs_query_%s_%s_%s_%s_%s_%d", sortField, sortDirection, resultFilter, searchTerm, queryType, limit)
// 缓存未命中,获取日志数据(已禁用缓存)
logs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
// 尝试从缓存获取
if s.cacheEnabled {
if cachedLogs, found := s.queryCache.Get(cacheKey); found {
// 缓存命中,直接返回
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(cachedLogs)
return
}
}
// 存入缓存(已禁用)
// if s.cacheEnabled {
// s.queryCache.Set(cacheKey, logs)
// }
// 构建过滤条件和分页参数
filter := log.LogFilter{
Result: resultFilter,
SearchTerm: searchTerm,
QueryType: queryType,
}
pageParams := log.PageParams{
Limit: limit,
Offset: offset,
SortField: sortField,
SortDirection: sortDirection,
}
var logs []log.QueryLog
var total int64
var err error
// 优先使用归档查询引擎(如果可用)
archiveQueryEngine := s.dnsServer.GetArchiveQueryEngine()
if archiveQueryEngine != nil {
logs, total, err = archiveQueryEngine.QueryLogs(filter, pageParams)
if err == nil {
// 归档查询成功,直接返回
totalPages := int64(0)
if limit > 0 {
totalPages = (total + int64(limit) - 1) / int64(limit)
}
response := map[string]interface{}{
"logs": logs,
"total": total,
"page": pageNum,
"limit": limit,
"totalPages": totalPages,
}
// 存入缓存(只缓存第一页)
if s.cacheEnabled && pageNum == 1 {
s.queryCache.Set(cacheKey, response)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
return
}
logger.Error("归档查询失败,降级到普通查询", "error", err)
}
// 使用原有查询方法
dnsLogs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
// 转换为 log.QueryLog 格式
logs = make([]log.QueryLog, len(dnsLogs))
for i, logItem := range dnsLogs {
// 将 Answers 转换为 JSON 字符串
answersJSON, _ := json.Marshal(logItem.Answers)
logs[i] = log.QueryLog{
Timestamp: logItem.Timestamp,
ClientIP: logItem.ClientIP,
Domain: logItem.Domain,
QueryType: logItem.QueryType,
ResponseTime: logItem.ResponseTime,
Result: logItem.Result,
BlockRule: logItem.BlockRule,
BlockType: logItem.BlockType,
FromCache: logItem.FromCache,
DNSSEC: logItem.DNSSEC,
EDNS: logItem.EDNS,
DNSServer: logItem.DNSServer,
DNSSECServer: logItem.DNSSECServer,
Answers: string(answersJSON),
ResponseCode: logItem.ResponseCode,
}
}
// 获取总记录数(用于计算总页数)
total = int64(s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType))
// 计算总页数
totalPages := int64(0)
if limit > 0 {
totalPages = (total + int64(limit) - 1) / int64(limit)
}
// 构建响应,包含分页信息
response := map[string]interface{}{
"logs": logs,
"total": total,
"page": pageNum,
"limit": limit,
"totalPages": totalPages,
}
// 存入缓存(只缓存第一页,因为用户最常查看第一页)
if s.cacheEnabled && pageNum == 1 {
s.queryCache.Set(cacheKey, response)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(logs)
json.NewEncoder(w).Encode(response)
}
// handleLogsCount 处理日志总数请求
@@ -2060,53 +2189,88 @@ func (s *Server) handleLogsCount(w http.ResponseWriter, r *http.Request) {
searchTerm := r.URL.Query().Get("search")
queryType := r.URL.Query().Get("queryType")
// 构建缓存键(已禁用)
// cacheKey := fmt.Sprintf("logs_count_%s_%s_%s", resultFilter, searchTerm, queryType)
// 构建缓存键
cacheKey := fmt.Sprintf("logs_count_%s_%s_%s", resultFilter, searchTerm, queryType)
// 缓存未命中,获取带过滤条件的日志总数(已禁用缓存)
// 尝试从缓存获取
if s.cacheEnabled {
if cachedCount, found := s.queryCache.Get(cacheKey); found {
// 缓存命中,直接返回
if count, ok := cachedCount.(int); ok {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]int{"count": count})
return
}
}
}
// 缓存未命中,获取带过滤条件的日志总数
count := s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType)
// 存入缓存(已禁用)
// if s.cacheEnabled {
// s.queryCache.Set(cacheKey, count)
// }
// 存入缓存
if s.cacheEnabled {
s.queryCache.Set(cacheKey, count)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]int{"count": count})
}
// handleDomainInfo 处理域名信息查询请求
func (s *Server) handleDomainInfo(w http.ResponseWriter, r *http.Request) {
// handleArchiveList 处理归档列表请求
func (s *Server) handleArchiveList(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 获取归档管理器
archiveManager := s.dnsServer.GetArchiveManager()
if archiveManager == nil {
http.Error(w, "归档功能未启用", http.StatusNotFound)
return
}
// 获取归档列表
archives, err := archiveManager.GetArchiveList()
if err != nil {
logger.Error("获取归档列表失败", "error", err)
http.Error(w, fmt.Sprintf("获取归档列表失败:%v", err), http.StatusInternalServerError)
return
}
// 返回归档列表
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(archives)
}
// handleArchiveCleanup 处理归档清理请求
func (s *Server) handleArchiveCleanup(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 解析请求体
var req struct {
Domain string `json:"domain"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
// 获取归档管理器
archiveManager := s.dnsServer.GetArchiveManager()
if archiveManager == nil {
http.Error(w, "归档功能未启用", http.StatusNotFound)
return
}
if req.Domain == "" {
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
return
}
// 从域名信息数据库中查询
domainInfo, err := domain.GetDomainInfo(req.Domain)
// 执行清理
deleted, err := archiveManager.CleanupOldArchives()
if err != nil {
http.Error(w, "Failed to query domain info", http.StatusInternalServerError)
logger.Error("清理归档失败", "error", err)
http.Error(w, fmt.Sprintf("清理归档失败:%v", err), http.StatusInternalServerError)
return
}
// 返回域名信息
// 返回清理结果
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(domainInfo)
json.NewEncoder(w).Encode(map[string]interface{}{
"deleted": deleted,
"message": fmt.Sprintf("成功清理 %d 个归档文件", deleted),
})
}
// handleRestart 处理重启服务请求
@@ -2264,17 +2428,17 @@ func isService(obj map[string]interface{}) bool {
_, hasName := obj["name"]
_, hasUrl := obj["url"]
_, hasCategoryId := obj["categoryId"]
// 如果有 name 和 url,则认为是服务
if hasName && hasUrl {
return true
}
// 如果有 categoryId,也认为是服务
if hasCategoryId {
return true
}
return false
}
@@ -2798,15 +2962,15 @@ func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
continue // 跳过标题行
}
if len(record) >= 4 {
threatType := record[0] // 第一列:类型
threatName := record[1] // 第二列:名称
riskLevel := record[2] // 第三列:风险等级
domain := record[3] // 第四列:域名
threatType := record[0] // 第一列:类型
threatName := record[1] // 第二列:名称
riskLevel := record[2] // 第三列:风险等级
domain := record[3] // 第四列:域名
threatInfo := []string{threatType, threatName, riskLevel}
// 1. 完整域名匹配(所有类型都添加)
threatMap[domain] = threatInfo
// 2. 只有恶意网站类型才添加子域名匹配规则
// 类型判断:钓鱼网站、仿冒网站
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
@@ -2822,10 +2986,10 @@ func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
}
}
}
// 查询单个域名
var result string
// 1. 先检查完整匹配
if threat, exists := threatMap[domain]; exists {
result = fmt.Sprintf("%s,%s,%s,%s", threat[0], threat[1], threat[2], domain)
@@ -2840,13 +3004,13 @@ func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
// 不是完整的子域名部分,跳过
continue
}
result = fmt.Sprintf("%s,%s,%s,%s", threatInfo[0], threatInfo[1], threatInfo[2], domain)
break
}
}
}
w.Header().Set("Content-Type", "application/json")
if result == "" {
// 未找到匹配的威胁信息
@@ -2915,15 +3079,15 @@ func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
continue // 跳过标题行
}
if len(record) >= 4 {
threatType := record[0] // 第一列:类型
threatName := record[1] // 第二列:名称
riskLevel := record[2] // 第三列:风险等级
domain := record[3] // 第四列:域名
threatType := record[0] // 第一列:类型
threatName := record[1] // 第二列:名称
riskLevel := record[2] // 第三列:风险等级
domain := record[3] // 第四列:域名
threatInfo := []string{threatType, threatName, riskLevel}
// 1. 完整域名匹配(所有类型都添加)
threatMap[domain] = threatInfo
// 2. 只有恶意网站类型才添加子域名匹配规则
// 类型判断:钓鱼网站、仿冒网站
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
@@ -2952,7 +3116,7 @@ func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
})
continue
}
// 2. 检查子域名匹配(遍历顶级域名规则)
matched := false
for threatDomain, threatInfo := range threatMap {
@@ -2964,7 +3128,7 @@ func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
// 去掉 threatDomain 的第一个字符(即去掉开头的点)
suffixToTrim := threatDomain[1:]
prefix := strings.TrimSuffix(domain, suffixToTrim)
// 验证逻辑:前缀不为空且以.结尾,或者前缀为空(完全匹配)
if len(prefix) == 0 || (len(prefix) > 0 && strings.HasSuffix(prefix, ".")) {
results = append(results, map[string]interface{}{
@@ -2977,7 +3141,7 @@ func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
}
}
}
if !matched {
results = append(results, map[string]interface{}{
"domain": domain,