增加威胁域名审计
This commit is contained in:
+241
-77
@@ -14,15 +14,14 @@ import (
|
||||
|
||||
"dns-server/config"
|
||||
"dns-server/dns"
|
||||
"dns-server/domain"
|
||||
"dns-server/gfw"
|
||||
"dns-server/log"
|
||||
"dns-server/logger"
|
||||
"dns-server/shield"
|
||||
|
||||
"gopkg.in/ini.v1"
|
||||
"dns-server/threat"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
// CacheEntry 缓存条目
|
||||
@@ -328,11 +327,11 @@ type Server struct {
|
||||
broadcastChan chan []byte
|
||||
|
||||
// 查询缓存相关字段
|
||||
queryCache *QueryCache // 查询结果缓存
|
||||
statsCache *StatsCache // 统计数据缓存
|
||||
cacheEnabled bool // 缓存是否启用
|
||||
cacheTTL time.Duration // 缓存过期时间
|
||||
cacheMaxSize int // 缓存最大条目数
|
||||
queryCache *QueryCache // 查询结果缓存
|
||||
statsCache *StatsCache // 统计数据缓存
|
||||
cacheEnabled bool // 缓存是否启用
|
||||
cacheTTL time.Duration // 缓存过期时间
|
||||
cacheMaxSize int // 缓存最大条目数
|
||||
}
|
||||
|
||||
// NewServer 创建HTTP服务器实例
|
||||
@@ -357,11 +356,11 @@ func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager
|
||||
sessions: make(map[string]time.Time),
|
||||
sessionTTL: 24 * time.Hour, // 会话有效期 24 小时
|
||||
// 查询缓存初始化
|
||||
queryCache: NewQueryCache(100, 5*time.Second), // 最多 100 条,5 秒过期
|
||||
statsCache: NewStatsCache(10, 2*time.Second), // 最多 10 条,2 秒过期
|
||||
cacheEnabled: true, // 默认启用缓存
|
||||
cacheTTL: 5 * time.Second, // 默认缓存 5 秒
|
||||
cacheMaxSize: 100, // 默认最大 100 条
|
||||
queryCache: NewQueryCache(100, 5*time.Second), // 最多 100 条,5 秒过期
|
||||
statsCache: NewStatsCache(10, 2*time.Second), // 最多 10 条,2 秒过期
|
||||
cacheEnabled: true, // 默认启用缓存
|
||||
cacheTTL: 5 * time.Second, // 默认缓存 5 秒
|
||||
cacheMaxSize: 100, // 默认最大 100 条
|
||||
}
|
||||
|
||||
// 启动广播协程
|
||||
@@ -440,8 +439,9 @@ func (s *Server) Start() error {
|
||||
mux.HandleFunc("/api/logs/stats", s.loginRequired(s.handleLogsStats))
|
||||
mux.HandleFunc("/api/logs/query", s.loginRequired(s.handleLogsQuery))
|
||||
mux.HandleFunc("/api/logs/count", s.loginRequired(s.handleLogsCount))
|
||||
// 域名查询相关接口
|
||||
mux.HandleFunc("/api/domain/info", s.loginRequired(s.handleDomainInfo))
|
||||
// 归档管理接口
|
||||
mux.HandleFunc("/api/logs/archives", s.loginRequired(s.handleArchiveList))
|
||||
mux.HandleFunc("/api/logs/archive-cleanup", s.loginRequired(s.handleArchiveCleanup))
|
||||
// 域名信息列表接口
|
||||
mux.HandleFunc("/api/domain-info", s.loginRequired(s.handleDomainInfoList))
|
||||
// 威胁查询接口
|
||||
@@ -1559,8 +1559,9 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
stats := s.dnsServer.GetStats()
|
||||
|
||||
// 使用服务器的实际启动时间计算准确的运行时间
|
||||
serverStartTime := s.dnsServer.GetStartTime()
|
||||
uptime := time.Since(serverStartTime)
|
||||
// 由于已移除旧日志系统,使用当前时间作为启动时间
|
||||
serverStartTime := time.Now()
|
||||
uptime := time.Duration(0)
|
||||
|
||||
// 构建包含所有真实服务器统计数据的响应
|
||||
status := map[string]interface{}{
|
||||
@@ -1999,6 +2000,30 @@ func (s *Server) handleLogsStats(w http.ResponseWriter, r *http.Request) {
|
||||
// 缓存未命中,获取最新统计数据
|
||||
logStats := s.dnsServer.GetQueryStats()
|
||||
|
||||
// 添加归档统计信息
|
||||
archiveManager := s.dnsServer.GetArchiveManager()
|
||||
if archiveManager != nil {
|
||||
archives, err := archiveManager.GetArchiveList()
|
||||
if err == nil {
|
||||
var archiveTotalRecords int64 = 0
|
||||
var archiveTotalCompressedSize int64 = 0
|
||||
for _, archive := range archives {
|
||||
archiveTotalRecords += archive.RecordCount
|
||||
archiveTotalCompressedSize += archive.CompressedSize
|
||||
}
|
||||
|
||||
logStats["archiveCount"] = len(archives)
|
||||
logStats["archiveTotalRecords"] = archiveTotalRecords
|
||||
logStats["archiveTotalCompressedSize"] = archiveTotalCompressedSize
|
||||
logStats["archiveTotalSize"] = archiveTotalCompressedSize // 压缩后的大小
|
||||
|
||||
// 如果有主库统计,计算总记录数
|
||||
if totalRecords, ok := logStats["totalQueries"].(int64); ok {
|
||||
logStats["grandTotalRecords"] = totalRecords + archiveTotalRecords
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 存入缓存
|
||||
if s.cacheEnabled {
|
||||
s.statsCache.Set(cacheKey, logStats)
|
||||
@@ -2016,8 +2041,8 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// 获取查询参数
|
||||
limit := 100 // 默认返回 100 条日志
|
||||
offset := 0
|
||||
limit := 30 // 默认返回 30 条日志
|
||||
pageNum := 1 // 默认第 1 页
|
||||
sortField := r.URL.Query().Get("sort")
|
||||
sortDirection := r.URL.Query().Get("direction")
|
||||
resultFilter := r.URL.Query().Get("result")
|
||||
@@ -2028,24 +2053,128 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Sscanf(limitStr, "%d", &limit)
|
||||
}
|
||||
|
||||
// 支持 page 参数(优先)或 offset 参数
|
||||
if pageStr := r.URL.Query().Get("page"); pageStr != "" {
|
||||
fmt.Sscanf(pageStr, "%d", &pageNum)
|
||||
}
|
||||
offset := (pageNum - 1) * limit
|
||||
|
||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||
fmt.Sscanf(offsetStr, "%d", &offset)
|
||||
}
|
||||
|
||||
// 构建缓存键,包含所有查询参数
|
||||
// 已禁用缓存,每次都从数据库获取最新数据
|
||||
// cacheKey := fmt.Sprintf("logs_%d_%d_%s_%s_%s_%s_%s", limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
|
||||
// 构建缓存键,包含所有查询参数(排除 offset,因为分页数据可以复用)
|
||||
cacheKey := fmt.Sprintf("logs_query_%s_%s_%s_%s_%s_%d", sortField, sortDirection, resultFilter, searchTerm, queryType, limit)
|
||||
|
||||
// 缓存未命中,获取日志数据(已禁用缓存)
|
||||
logs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
|
||||
// 尝试从缓存获取
|
||||
if s.cacheEnabled {
|
||||
if cachedLogs, found := s.queryCache.Get(cacheKey); found {
|
||||
// 缓存命中,直接返回
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(cachedLogs)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 存入缓存(已禁用)
|
||||
// if s.cacheEnabled {
|
||||
// s.queryCache.Set(cacheKey, logs)
|
||||
// }
|
||||
// 构建过滤条件和分页参数
|
||||
filter := log.LogFilter{
|
||||
Result: resultFilter,
|
||||
SearchTerm: searchTerm,
|
||||
QueryType: queryType,
|
||||
}
|
||||
pageParams := log.PageParams{
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
SortField: sortField,
|
||||
SortDirection: sortDirection,
|
||||
}
|
||||
|
||||
var logs []log.QueryLog
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
// 优先使用归档查询引擎(如果可用)
|
||||
archiveQueryEngine := s.dnsServer.GetArchiveQueryEngine()
|
||||
if archiveQueryEngine != nil {
|
||||
logs, total, err = archiveQueryEngine.QueryLogs(filter, pageParams)
|
||||
if err == nil {
|
||||
// 归档查询成功,直接返回
|
||||
totalPages := int64(0)
|
||||
if limit > 0 {
|
||||
totalPages = (total + int64(limit) - 1) / int64(limit)
|
||||
}
|
||||
response := map[string]interface{}{
|
||||
"logs": logs,
|
||||
"total": total,
|
||||
"page": pageNum,
|
||||
"limit": limit,
|
||||
"totalPages": totalPages,
|
||||
}
|
||||
|
||||
// 存入缓存(只缓存第一页)
|
||||
if s.cacheEnabled && pageNum == 1 {
|
||||
s.queryCache.Set(cacheKey, response)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
return
|
||||
}
|
||||
logger.Error("归档查询失败,降级到普通查询", "error", err)
|
||||
}
|
||||
|
||||
// 使用原有查询方法
|
||||
dnsLogs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
|
||||
|
||||
// 转换为 log.QueryLog 格式
|
||||
logs = make([]log.QueryLog, len(dnsLogs))
|
||||
for i, logItem := range dnsLogs {
|
||||
// 将 Answers 转换为 JSON 字符串
|
||||
answersJSON, _ := json.Marshal(logItem.Answers)
|
||||
|
||||
logs[i] = log.QueryLog{
|
||||
Timestamp: logItem.Timestamp,
|
||||
ClientIP: logItem.ClientIP,
|
||||
Domain: logItem.Domain,
|
||||
QueryType: logItem.QueryType,
|
||||
ResponseTime: logItem.ResponseTime,
|
||||
Result: logItem.Result,
|
||||
BlockRule: logItem.BlockRule,
|
||||
BlockType: logItem.BlockType,
|
||||
FromCache: logItem.FromCache,
|
||||
DNSSEC: logItem.DNSSEC,
|
||||
EDNS: logItem.EDNS,
|
||||
DNSServer: logItem.DNSServer,
|
||||
DNSSECServer: logItem.DNSSECServer,
|
||||
Answers: string(answersJSON),
|
||||
ResponseCode: logItem.ResponseCode,
|
||||
}
|
||||
}
|
||||
// 获取总记录数(用于计算总页数)
|
||||
total = int64(s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType))
|
||||
|
||||
// 计算总页数
|
||||
totalPages := int64(0)
|
||||
if limit > 0 {
|
||||
totalPages = (total + int64(limit) - 1) / int64(limit)
|
||||
}
|
||||
|
||||
// 构建响应,包含分页信息
|
||||
response := map[string]interface{}{
|
||||
"logs": logs,
|
||||
"total": total,
|
||||
"page": pageNum,
|
||||
"limit": limit,
|
||||
"totalPages": totalPages,
|
||||
}
|
||||
|
||||
// 存入缓存(只缓存第一页,因为用户最常查看第一页)
|
||||
if s.cacheEnabled && pageNum == 1 {
|
||||
s.queryCache.Set(cacheKey, response)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(logs)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// handleLogsCount 处理日志总数请求
|
||||
@@ -2060,53 +2189,88 @@ func (s *Server) handleLogsCount(w http.ResponseWriter, r *http.Request) {
|
||||
searchTerm := r.URL.Query().Get("search")
|
||||
queryType := r.URL.Query().Get("queryType")
|
||||
|
||||
// 构建缓存键(已禁用)
|
||||
// cacheKey := fmt.Sprintf("logs_count_%s_%s_%s", resultFilter, searchTerm, queryType)
|
||||
// 构建缓存键
|
||||
cacheKey := fmt.Sprintf("logs_count_%s_%s_%s", resultFilter, searchTerm, queryType)
|
||||
|
||||
// 缓存未命中,获取带过滤条件的日志总数(已禁用缓存)
|
||||
// 尝试从缓存获取
|
||||
if s.cacheEnabled {
|
||||
if cachedCount, found := s.queryCache.Get(cacheKey); found {
|
||||
// 缓存命中,直接返回
|
||||
if count, ok := cachedCount.(int); ok {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]int{"count": count})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存未命中,获取带过滤条件的日志总数
|
||||
count := s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType)
|
||||
|
||||
// 存入缓存(已禁用)
|
||||
// if s.cacheEnabled {
|
||||
// s.queryCache.Set(cacheKey, count)
|
||||
// }
|
||||
// 存入缓存
|
||||
if s.cacheEnabled {
|
||||
s.queryCache.Set(cacheKey, count)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]int{"count": count})
|
||||
}
|
||||
|
||||
// handleDomainInfo 处理域名信息查询请求
|
||||
func (s *Server) handleDomainInfo(w http.ResponseWriter, r *http.Request) {
|
||||
// handleArchiveList 处理归档列表请求
|
||||
func (s *Server) handleArchiveList(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取归档管理器
|
||||
archiveManager := s.dnsServer.GetArchiveManager()
|
||||
if archiveManager == nil {
|
||||
http.Error(w, "归档功能未启用", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取归档列表
|
||||
archives, err := archiveManager.GetArchiveList()
|
||||
if err != nil {
|
||||
logger.Error("获取归档列表失败", "error", err)
|
||||
http.Error(w, fmt.Sprintf("获取归档列表失败:%v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回归档列表
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(archives)
|
||||
}
|
||||
|
||||
// handleArchiveCleanup 处理归档清理请求
|
||||
func (s *Server) handleArchiveCleanup(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var req struct {
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
// 获取归档管理器
|
||||
archiveManager := s.dnsServer.GetArchiveManager()
|
||||
if archiveManager == nil {
|
||||
http.Error(w, "归档功能未启用", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Domain == "" {
|
||||
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 从域名信息数据库中查询
|
||||
domainInfo, err := domain.GetDomainInfo(req.Domain)
|
||||
// 执行清理
|
||||
deleted, err := archiveManager.CleanupOldArchives()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to query domain info", http.StatusInternalServerError)
|
||||
logger.Error("清理归档失败", "error", err)
|
||||
http.Error(w, fmt.Sprintf("清理归档失败:%v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回域名信息
|
||||
// 返回清理结果
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(domainInfo)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"deleted": deleted,
|
||||
"message": fmt.Sprintf("成功清理 %d 个归档文件", deleted),
|
||||
})
|
||||
}
|
||||
|
||||
// handleRestart 处理重启服务请求
|
||||
@@ -2264,17 +2428,17 @@ func isService(obj map[string]interface{}) bool {
|
||||
_, hasName := obj["name"]
|
||||
_, hasUrl := obj["url"]
|
||||
_, hasCategoryId := obj["categoryId"]
|
||||
|
||||
|
||||
// 如果有 name 和 url,则认为是服务
|
||||
if hasName && hasUrl {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
// 如果有 categoryId,也认为是服务
|
||||
if hasCategoryId {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2798,15 +2962,15 @@ func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
|
||||
continue // 跳过标题行
|
||||
}
|
||||
if len(record) >= 4 {
|
||||
threatType := record[0] // 第一列:类型
|
||||
threatName := record[1] // 第二列:名称
|
||||
riskLevel := record[2] // 第三列:风险等级
|
||||
domain := record[3] // 第四列:域名
|
||||
threatType := record[0] // 第一列:类型
|
||||
threatName := record[1] // 第二列:名称
|
||||
riskLevel := record[2] // 第三列:风险等级
|
||||
domain := record[3] // 第四列:域名
|
||||
threatInfo := []string{threatType, threatName, riskLevel}
|
||||
|
||||
|
||||
// 1. 完整域名匹配(所有类型都添加)
|
||||
threatMap[domain] = threatInfo
|
||||
|
||||
|
||||
// 2. 只有恶意网站类型才添加子域名匹配规则
|
||||
// 类型判断:钓鱼网站、仿冒网站
|
||||
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
|
||||
@@ -2822,10 +2986,10 @@ func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 查询单个域名
|
||||
var result string
|
||||
|
||||
|
||||
// 1. 先检查完整匹配
|
||||
if threat, exists := threatMap[domain]; exists {
|
||||
result = fmt.Sprintf("%s,%s,%s,%s", threat[0], threat[1], threat[2], domain)
|
||||
@@ -2840,13 +3004,13 @@ func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
|
||||
// 不是完整的子域名部分,跳过
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
result = fmt.Sprintf("%s,%s,%s,%s", threatInfo[0], threatInfo[1], threatInfo[2], domain)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if result == "" {
|
||||
// 未找到匹配的威胁信息
|
||||
@@ -2915,15 +3079,15 @@ func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
|
||||
continue // 跳过标题行
|
||||
}
|
||||
if len(record) >= 4 {
|
||||
threatType := record[0] // 第一列:类型
|
||||
threatName := record[1] // 第二列:名称
|
||||
riskLevel := record[2] // 第三列:风险等级
|
||||
domain := record[3] // 第四列:域名
|
||||
threatType := record[0] // 第一列:类型
|
||||
threatName := record[1] // 第二列:名称
|
||||
riskLevel := record[2] // 第三列:风险等级
|
||||
domain := record[3] // 第四列:域名
|
||||
threatInfo := []string{threatType, threatName, riskLevel}
|
||||
|
||||
|
||||
// 1. 完整域名匹配(所有类型都添加)
|
||||
threatMap[domain] = threatInfo
|
||||
|
||||
|
||||
// 2. 只有恶意网站类型才添加子域名匹配规则
|
||||
// 类型判断:钓鱼网站、仿冒网站
|
||||
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
|
||||
@@ -2952,7 +3116,7 @@ func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// 2. 检查子域名匹配(遍历顶级域名规则)
|
||||
matched := false
|
||||
for threatDomain, threatInfo := range threatMap {
|
||||
@@ -2964,7 +3128,7 @@ func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
|
||||
// 去掉 threatDomain 的第一个字符(即去掉开头的点)
|
||||
suffixToTrim := threatDomain[1:]
|
||||
prefix := strings.TrimSuffix(domain, suffixToTrim)
|
||||
|
||||
|
||||
// 验证逻辑:前缀不为空且以.结尾,或者前缀为空(完全匹配)
|
||||
if len(prefix) == 0 || (len(prefix) > 0 && strings.HasSuffix(prefix, ".")) {
|
||||
results = append(results, map[string]interface{}{
|
||||
@@ -2977,7 +3141,7 @@ func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if !matched {
|
||||
results = append(results, map[string]interface{}{
|
||||
"domain": domain,
|
||||
|
||||
Reference in New Issue
Block a user