Files
dns-server/shield/domain_info_manager.go
T
Alex Yang f9e2e5a6bc update
2026-04-12 21:40:22 +08:00

757 lines
18 KiB
Go

package shield
import (
"context"
"crypto/sha256"
"encoding/csv"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"dns-server/config"
"dns-server/logger"
)
// DomainInfoManager downloads, caches and serves per-domain metadata from the
// configured "domain-info", "threat-database" and "tracker" lists.
type DomainInfoManager struct {
// config supplies the list definitions and auto-update settings.
config *config.DomainInfoConfig
// domainInfoMap maps a hostname to the service record parsed from a
// domain-info list; fetchRemoteDomainInfo adds "company" and "service" keys.
domainInfoMap map[string]map[string]interface{}
// threatDatabase maps a domain to its threat record with string values for
// "type", "name", "riskLevel" and "domain".
threatDatabase map[string]map[string]string
// trackerMap maps a domain to its tracker record.
trackerMap map[string]map[string]interface{}
// mutex is used to guard the three tables above and the update-state fields
// (updateRunning, updateStatus, lastUpdateTime).
mutex sync.RWMutex
// updateCtx is cancelled by Stop (via updateCancel) to end autoUpdateLoop.
updateCtx context.Context
updateCancel context.CancelFunc
// updateRunning makes concurrent LoadDomainInfo calls return early.
updateRunning bool
// updateStatus holds progress per list type plus the "overall" key
// ("running", "completed" or "failed").
updateStatus map[string]string
// lastUpdateTime is when LoadDomainInfo last finished; zero means never.
lastUpdateTime time.Time
// updateInterval is both the auto-update period and the maximum age of a
// per-URL download in cacheDir before it is fetched again.
updateInterval time.Duration
// cacheDir holds the per-URL raw downloads and all_domain_info.json.
cacheDir string
// statsFile is the JSON file written by saveStats and read by loadStats.
statsFile string
}
// DomainInfo is a JSON-serializable description of one domain.
// NOTE(review): none of the loaders in this file construct it; presumably it
// is built by API handlers elsewhere — confirm.
type DomainInfo struct {
Domain string `json:"domain"`
Category string `json:"category"`
Type string `json:"type"`
// RiskLevel is an integer risk score; its scale is not defined in this file.
RiskLevel int `json:"riskLevel"`
// Info carries the raw record, omitted from JSON when empty.
Info map[string]interface{} `json:"info,omitempty"`
}
// ThreatInfo is a JSON-serializable threat record. Its fields mirror the keys
// fetchThreatDatabase stores, except that RiskLevel is an int here while the
// in-memory table keeps "riskLevel" as the raw CSV string.
type ThreatInfo struct {
Type string `json:"type"`
Name string `json:"name"`
RiskLevel int `json:"riskLevel"`
Domain string `json:"domain"`
}
// TrackerInfo is a JSON-serializable tracker record. Note that entries
// synthesized from a list's "trackerDomains" section only carry "tracker" and
// "domain", so Category may be empty for them.
type TrackerInfo struct {
Domain string `json:"domain"`
Category string `json:"category"`
Tracker string `json:"tracker"`
// Info carries the raw record, omitted from JSON when empty.
Info map[string]interface{} `json:"info,omitempty"`
}
// NewDomainInfoManager creates a manager for cfg with empty tables. No data
// is loaded until Start is called. cfg.UpdateInterval is interpreted as a
// number of seconds. The cache directory is created eagerly; if that fails
// the error is logged and the manager is still returned.
func NewDomainInfoManager(cfg *config.DomainInfoConfig) *DomainInfoManager {
	const (
		cacheDir  = "data/domain_info_cache"
		statsFile = "data/domain_info_stats.json"
	)

	ctx, cancel := context.WithCancel(context.Background())
	m := &DomainInfoManager{
		config:         cfg,
		domainInfoMap:  make(map[string]map[string]interface{}),
		threatDatabase: make(map[string]map[string]string),
		trackerMap:     make(map[string]map[string]interface{}),
		updateCtx:      ctx,
		updateCancel:   cancel,
		updateStatus:   make(map[string]string),
		updateInterval: time.Second * time.Duration(cfg.UpdateInterval),
		cacheDir:       cacheDir,
		statsFile:      statsFile,
	}

	if err := os.MkdirAll(m.cacheDir, 0755); err != nil {
		logger.Error("创建域名信息缓存目录失败", "error", err)
	}
	return m
}
// Start returns immediately. It kicks off a background goroutine that:
//  1. loads the persisted stats (a missing file just means defaults),
//  2. loads the domain tables from the on-disk cache (or from the network),
//  3. launches autoUpdateLoop when auto-update is enabled.
func (m *DomainInfoManager) Start() {
	logger.Info("启动域名信息管理器")

	initialize := func() {
		if m.loadStats() != nil {
			logger.Info("未找到域名信息统计文件,将使用默认值")
		}
		m.loadFromCache()
		if !m.config.EnableAutoUpdate {
			return
		}
		go m.autoUpdateLoop()
	}
	go initialize()
}
// Stop cancels the update context, which ends autoUpdateLoop. It does not
// wait for a LoadDomainInfo call that is already in progress.
func (m *DomainInfoManager) Stop() {
	defer m.updateCancel()
	logger.Info("停止域名信息管理器")
}
// autoUpdateLoop calls LoadDomainInfo every updateInterval until updateCtx is
// cancelled (see Stop). Failures are logged and the loop keeps running.
//
// A non-positive interval (e.g. UpdateInterval left at 0) disables the loop:
// time.NewTicker panics on such a value, and because this runs in its own
// goroutine that panic would take down the whole server.
func (m *DomainInfoManager) autoUpdateLoop() {
	if m.updateInterval <= 0 {
		logger.Warn("域名信息自动更新间隔无效,已禁用自动更新", "interval", m.updateInterval)
		return
	}

	ticker := time.NewTicker(m.updateInterval)
	defer ticker.Stop()
	for {
		select {
		case <-m.updateCtx.Done():
			return
		case <-ticker.C:
			logger.Info("开始自动更新域名信息")
			if err := m.LoadDomainInfo(); err != nil {
				logger.Error("自动更新域名信息失败", "error", err)
			}
		}
	}
}
// LoadDomainInfo refreshes every enabled list in the configuration, one
// goroutine per list, then stamps lastUpdateTime and rewrites the aggregated
// cache and stats files. If a refresh is already in progress it logs and
// returns nil without doing anything.
//
// Progress is published through updateStatus. Each list's Type moves to
// "running" and then "completed" or "failed". "overall" is "running" while the
// call executes and "completed" when it returns. Per-list failures are joined
// into one returned error; whatever did load is still persisted.
func (m *DomainInfoManager) LoadDomainInfo() error {
	m.mutex.Lock()
	if m.updateRunning {
		m.mutex.Unlock()
		logger.Info("域名信息正在更新中,跳过")
		return nil
	}
	m.updateRunning = true
	m.updateStatus["overall"] = "running"
	m.mutex.Unlock()
	defer func() {
		m.mutex.Lock()
		m.updateRunning = false
		m.updateStatus["overall"] = "completed"
		m.mutex.Unlock()
	}()

	lists := m.config.DomainInfoLists
	// One buffer slot per configured list, so a worker can never block on
	// send no matter how many lists fail. A fixed small capacity would hang
	// the whole update once that many errors were queued.
	errChan := make(chan error, len(lists))
	var wg sync.WaitGroup
	for _, entry := range lists {
		if !entry.Enabled {
			continue
		}
		wg.Add(1)
		go func(entry config.DomainInfoEntry) {
			defer wg.Done()

			m.mutex.Lock()
			m.updateStatus[entry.Type] = "running"
			m.mutex.Unlock()

			var err error
			switch entry.Type {
			case "domain-info":
				err = m.fetchRemoteDomainInfo(entry.URL)
			case "threat-database":
				err = m.fetchThreatDatabase(entry.URL)
			case "tracker":
				err = m.fetchTrackerInfo(entry.URL)
			default:
				logger.Warn("未知的域名信息类型", "type", entry.Type)
			}

			status := "completed"
			if err != nil {
				status = "failed"
				// Sent without holding m.mutex: a send must never block
				// while the lock is held.
				errChan <- fmt.Errorf("更新 %s 失败: %v", entry.Name, err)
			}
			m.mutex.Lock()
			m.updateStatus[entry.Type] = status
			m.mutex.Unlock()
		}(entry)
	}
	wg.Wait()
	close(errChan)

	var errs []string
	for err := range errChan {
		errs = append(errs, err.Error())
	}

	m.mutex.Lock()
	m.lastUpdateTime = time.Now()
	domainInfoCount := len(m.domainInfoMap)
	threatCount := len(m.threatDatabase)
	trackerCount := len(m.trackerMap)
	m.mutex.Unlock()

	// Must be called with m.mutex released: saveToCache finishes with
	// saveStats, which takes m.mutex.RLock, and sync.RWMutex is not
	// reentrant. Calling it under the write lock deadlocks the update.
	m.saveToCache()

	if len(errs) > 0 {
		return fmt.Errorf("更新过程中有错误: %s", strings.Join(errs, "; "))
	}
	logger.Info("域名信息更新完成",
		"domainInfoCount", domainInfoCount,
		"threatCount", threatCount,
		"trackerCount", trackerCount)
	return nil
}
// fetchRemoteDomainInfo fetches the domain-info JSON at url (see
// fetchRemoteData for caching) and replaces domainInfoMap with its contents.
//
// The document is read as {"domains": {company: {service: {"url": ..., ...}}}}.
// Each service record is keyed by the hostname of its "url" and annotated
// with "company" and "service". Records with a non-string url, or a url with
// no extractable host, are skipped. The table is replaced even when the
// "domains" key is absent, leaving it empty.
func (m *DomainInfoManager) fetchRemoteDomainInfo(url string) error {
	raw, err := m.fetchRemoteData(url)
	if err != nil {
		return err
	}
	var doc map[string]interface{}
	if err := json.Unmarshal(raw, &doc); err != nil {
		return fmt.Errorf("解析域名信息 JSON 失败: %v", err)
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()

	m.domainInfoMap = make(map[string]map[string]interface{})
	companies, _ := doc["domains"].(map[string]interface{})
	for company, services := range companies {
		serviceMap, isMap := services.(map[string]interface{})
		if !isMap {
			continue
		}
		for serviceName, details := range serviceMap {
			record, isMap := details.(map[string]interface{})
			if !isMap {
				continue
			}
			rawURL, isString := record["url"].(string)
			if !isString {
				continue
			}
			host := m.extractDomainFromURL(rawURL)
			if host == "" {
				continue
			}
			record["company"] = company
			record["service"] = serviceName
			m.domainInfoMap[host] = record
		}
	}

	logger.Info("加载域名信息", "count", len(m.domainInfoMap), "url", url)
	return nil
}
// fetchThreatDatabase fetches the threat CSV at url (see fetchRemoteData for
// caching) and replaces threatDatabase with its rows.
//
// The first row is treated as a header. Every other row is read as
// type,name,riskLevel,domain. Extra columns are ignored; rows with fewer than
// four columns or a blank domain are skipped. Values are whitespace-trimmed
// and stored as strings.
func (m *DomainInfoManager) fetchThreatDatabase(url string) error {
	data, err := m.fetchRemoteData(url)
	if err != nil {
		return err
	}

	reader := csv.NewReader(strings.NewReader(string(data)))
	// Short rows are filtered below. Without this, the reader's default
	// (every row must match the first row's width) would reject the whole
	// file because of a single ragged line.
	reader.FieldsPerRecord = -1
	records, err := reader.ReadAll()
	if err != nil {
		return fmt.Errorf("解析威胁数据库 CSV 失败: %v", err)
	}

	threats := make(map[string]map[string]string, len(records))
	for i, record := range records {
		if i == 0 || len(record) < 4 {
			continue
		}
		domain := strings.TrimSpace(record[3])
		if domain == "" {
			continue
		}
		threats[domain] = map[string]string{
			"type":      strings.TrimSpace(record[0]),
			"name":      strings.TrimSpace(record[1]),
			"riskLevel": strings.TrimSpace(record[2]),
			"domain":    domain,
		}
	}

	// Swap in a fresh table, as the domain-info and tracker loaders do. This
	// way domains dropped from the feed stop being flagged instead of lingering
	// forever.
	m.mutex.Lock()
	m.threatDatabase = threats
	m.mutex.Unlock()

	logger.Info("加载威胁数据库", "count", len(threats), "url", url)
	return nil
}
// fetchTrackerInfo fetches the tracker JSON at url (see fetchRemoteData for
// caching) and replaces trackerMap with its contents.
//
// Two sections are read:
//   - "trackers": {domain: record}. Each object record is stored as-is.
//   - "trackerDomains": {trackerName: [domain, ...]}. A minimal record
//     {"tracker", "domain"} is created for each listed domain that the
//     "trackers" section did not already provide.
func (m *DomainInfoManager) fetchTrackerInfo(url string) error {
	payload, err := m.fetchRemoteData(url)
	if err != nil {
		return err
	}
	var doc map[string]interface{}
	if err := json.Unmarshal(payload, &doc); err != nil {
		return fmt.Errorf("解析跟踪器信息 JSON 失败: %v", err)
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()

	m.trackerMap = make(map[string]map[string]interface{})

	explicit, _ := doc["trackers"].(map[string]interface{})
	for domain, value := range explicit {
		if record, isMap := value.(map[string]interface{}); isMap {
			m.trackerMap[domain] = record
		}
	}

	byTracker, _ := doc["trackerDomains"].(map[string]interface{})
	for trackerName, listed := range byTracker {
		items, isList := listed.([]interface{})
		if !isList {
			continue
		}
		for _, item := range items {
			domain, isString := item.(string)
			if !isString {
				continue
			}
			if _, seen := m.trackerMap[domain]; seen {
				continue
			}
			m.trackerMap[domain] = map[string]interface{}{
				"tracker": trackerName,
				"domain":  domain,
			}
		}
	}

	logger.Info("加载跟踪器信息", "count", len(m.trackerMap), "url", url)
	return nil
}
// fetchRemoteData returns the raw body of the list at url.
//
// A copy cached in cacheDir is served while it is younger than updateInterval.
// Otherwise the list is downloaded and the cache file rewritten. If the
// download fails for any reason (request error, non-200 status, truncated
// body), the existing cached copy is returned instead when there is one. If
// there is no cached copy, the returned error names both failures.
//
// Downloads have an overall timeout and are tied to updateCtx, so Stop aborts
// an in-flight request; that request then falls back to the cache.
func (m *DomainInfoManager) fetchRemoteData(url string) ([]byte, error) {
	// Upper bound for one list download, body included. http.Get uses a
	// client with no timeout, so a stalled server could otherwise block
	// LoadDomainInfo, and its updateRunning flag, indefinitely.
	const fetchTimeout = 2 * time.Minute

	cacheFile := m.getCacheFilePath(url)
	if !m.shouldUpdateCache(cacheFile) {
		return m.loadFromCacheFile(cacheFile)
	}

	// fallback serves the previous download so a flaky mirror does not wipe
	// out known-good data.
	fallback := func(fetchErr error) ([]byte, error) {
		logger.Warn("从远程获取数据失败,尝试使用缓存", "url", url, "error", fetchErr)
		cached, err := m.loadFromCacheFile(cacheFile)
		if err != nil {
			return nil, fmt.Errorf("%v; 读取缓存失败: %v", fetchErr, err)
		}
		return cached, nil
	}

	req, err := http.NewRequestWithContext(m.updateCtx, http.MethodGet, url, nil)
	if err != nil {
		return fallback(err)
	}
	client := &http.Client{Timeout: fetchTimeout}
	resp, err := client.Do(req)
	if err != nil {
		return fallback(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fallback(fmt.Errorf("HTTP 状态码: %d", resp.StatusCode))
	}
	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return fallback(fmt.Errorf("读取响应失败: %v", err))
	}

	if err := ioutil.WriteFile(cacheFile, data, 0644); err != nil {
		logger.Warn("保存缓存文件失败", "error", err)
	}
	return data, nil
}
// getCacheFilePath maps a list URL to its raw-download cache file inside
// cacheDir. The file name is the lowercase hex SHA-256 of the URL followed
// by ".cache".
func (m *DomainInfoManager) getCacheFilePath(url string) string {
	digest := sha256.Sum256([]byte(url))
	name := fmt.Sprintf("%x.cache", digest)
	return filepath.Join(m.cacheDir, name)
}
// shouldUpdateCache reports whether cacheFile must be downloaded again. It
// returns true when the file cannot be stat'ed (including when it does not
// exist) or when its modification time is more than updateInterval ago.
func (m *DomainInfoManager) shouldUpdateCache(cacheFile string) bool {
	info, err := os.Stat(cacheFile)
	if err != nil {
		return true
	}
	age := time.Since(info.ModTime())
	return age > m.updateInterval
}
// loadFromCacheFile returns the bytes of a previously downloaded list. On
// error the returned slice is always nil, never a partial read.
func (m *DomainInfoManager) loadFromCacheFile(cacheFile string) (content []byte, err error) {
	if content, err = ioutil.ReadFile(cacheFile); err != nil {
		content = nil
	}
	return content, err
}
// saveToCache serializes the three domain tables plus lastUpdateTime to
// <cacheDir>/all_domain_info.json, the file loadFromCache reads at startup.
// It then rewrites the stats file through saveStats. Failures are logged,
// not returned.
//
// It acquires m.mutex itself, so it must be called WITHOUT m.mutex held:
// sync.RWMutex is not reentrant, and a nested acquisition deadlocks.
func (m *DomainInfoManager) saveToCache() {
	// Hold the read lock only while encoding. json.Marshal walks the maps,
	// which would otherwise race with loaders replacing or filling them.
	m.mutex.RLock()
	jsonData, err := json.Marshal(map[string]interface{}{
		"domainInfoMap":  m.domainInfoMap,
		"threatDatabase": m.threatDatabase,
		"trackerMap":     m.trackerMap,
		"lastUpdateTime": m.lastUpdateTime,
	})
	m.mutex.RUnlock()
	if err != nil {
		logger.Error("序列化域名信息缓存失败", "error", err)
		return
	}

	cacheFile := filepath.Join(m.cacheDir, "all_domain_info.json")
	if err := ioutil.WriteFile(cacheFile, jsonData, 0644); err != nil {
		logger.Error("保存域名信息缓存失败", "error", err)
	}

	// saveStats takes its own read lock, so it must run after ours is released.
	m.saveStats()
}
// loadFromCache restores the domain tables from <cacheDir>/all_domain_info.json
// (written by saveToCache). If that file cannot be read, it performs a full
// LoadDomainInfo instead. Cached entries are merged into the current tables
// under m.mutex. lastUpdateTime is restored when no newer update has been
// recorded, and the stats file is rewritten afterwards.
func (m *DomainInfoManager) loadFromCache() {
	cacheFile := filepath.Join(m.cacheDir, "all_domain_info.json")
	data, err := ioutil.ReadFile(cacheFile)
	if err != nil {
		logger.Info("未找到域名信息缓存,将从远程加载")
		if err := m.LoadDomainInfo(); err != nil {
			logger.Error("初始加载域名信息失败", "error", err)
		}
		return
	}
	var cacheData map[string]interface{}
	if err := json.Unmarshal(data, &cacheData); err != nil {
		logger.Error("解析域名信息缓存失败", "error", err)
		return
	}

	// Decode into locals first. This runs on Start's background goroutine
	// while lookups may already be served. Writing the shared maps without
	// the lock is a data race that Go can abort the process on.
	domainInfo := make(map[string]map[string]interface{})
	if raw, ok := cacheData["domainInfoMap"].(map[string]interface{}); ok {
		for k, v := range raw {
			if vm, ok := v.(map[string]interface{}); ok {
				domainInfo[k] = vm
			}
		}
	}
	threats := make(map[string]map[string]string)
	if raw, ok := cacheData["threatDatabase"].(map[string]interface{}); ok {
		for k, v := range raw {
			vm, ok := v.(map[string]interface{})
			if !ok {
				continue
			}
			strMap := make(map[string]string, len(vm))
			for kk, vv := range vm {
				if vs, ok := vv.(string); ok {
					strMap[kk] = vs
				}
			}
			threats[k] = strMap
		}
	}
	trackers := make(map[string]map[string]interface{})
	if raw, ok := cacheData["trackerMap"].(map[string]interface{}); ok {
		for k, v := range raw {
			if vm, ok := v.(map[string]interface{}); ok {
				trackers[k] = vm
			}
		}
	}
	// saveToCache stores lastUpdateTime as a time.Time, which encoding/json
	// writes as an RFC 3339 string. Restoring it keeps GetAllDomainInfo, and
	// the stats file rewritten below, from reporting "never updated" after a
	// restart.
	var savedAt time.Time
	if s, ok := cacheData["lastUpdateTime"].(string); ok {
		if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
			savedAt = t
		}
	}

	m.mutex.Lock()
	for k, v := range domainInfo {
		m.domainInfoMap[k] = v
	}
	for k, v := range threats {
		m.threatDatabase[k] = v
	}
	for k, v := range trackers {
		m.trackerMap[k] = v
	}
	if m.lastUpdateTime.IsZero() {
		m.lastUpdateTime = savedAt
	}
	domainInfoCount := len(m.domainInfoMap)
	threatCount := len(m.threatDatabase)
	trackerCount := len(m.trackerMap)
	m.mutex.Unlock()

	logger.Info("从缓存加载域名信息完成",
		"domainInfoCount", domainInfoCount,
		"threatCount", threatCount,
		"trackerCount", trackerCount)
	// saveStats takes m.mutex itself, so it runs after the lock is released.
	m.saveStats()
}
// GetDomainInfo returns the domain-info record stored under exactly domain
// (no case folding or parent-domain matching), or nil if there is none.
// The stored map itself is returned, not a copy.
func (m *DomainInfoManager) GetDomainInfo(domain string) map[string]interface{} {
	m.mutex.RLock()
	record := m.domainInfoMap[domain]
	m.mutex.RUnlock()
	return record
}
// GetThreatInfo returns the threat record stored under exactly domain, or nil
// if the domain is not listed. fetchThreatDatabase stores records with the
// keys "type", "name", "riskLevel" and "domain". The stored map itself is
// returned, not a copy.
func (m *DomainInfoManager) GetThreatInfo(domain string) map[string]string {
	m.mutex.RLock()
	record := m.threatDatabase[domain]
	m.mutex.RUnlock()
	return record
}
// GetTrackerInfo returns the tracker record stored under exactly domain, or
// nil if there is none. The stored map itself is returned, not a copy.
func (m *DomainInfoManager) GetTrackerInfo(domain string) map[string]interface{} {
	m.mutex.RLock()
	record := m.trackerMap[domain]
	m.mutex.RUnlock()
	return record
}
// extractDomainFromURL returns the host part of urlStr without any port, or
// "" if urlStr does not parse. A string without a scheme (for example
// "example.com/x") has no host component in net/url's view and also yields "".
func (m *DomainInfoManager) extractDomainFromURL(urlStr string) string {
	if parsed, err := url.Parse(urlStr); err == nil {
		return parsed.Hostname()
	}
	return ""
}
// GetAllDomainInfo summarizes what is loaded. It returns:
//   - "lists": a copy of the configured lists, with RuleCount set from the
//     table matching each list's Type and, once an update has been recorded,
//     LastUpdateTime set to that time;
//   - "domainInfoCount", "threatCount", "trackerCount": the table sizes;
//   - "lastUpdateTime": the last update as RFC 3339, or "从未更新" if none.
func (m *DomainInfoManager) GetAllDomainInfo() map[string]interface{} {
	m.mutex.RLock()
	countByType := map[string]int{
		"domain-info":     len(m.domainInfoMap),
		"threat-database": len(m.threatDatabase),
		"tracker":         len(m.trackerMap),
	}
	updatedAt := m.lastUpdateTime
	m.mutex.RUnlock()

	updated := !updatedAt.IsZero()
	updatedLabel := "从未更新"
	if updated {
		updatedLabel = updatedAt.Format(time.RFC3339)
	}

	// Annotate a copy so the configuration entries themselves are untouched.
	lists := []config.DomainInfoEntry{}
	if m.config != nil {
		lists = make([]config.DomainInfoEntry, len(m.config.DomainInfoLists))
		copy(lists, m.config.DomainInfoLists)
	}
	for i := range lists {
		if n, known := countByType[lists[i].Type]; known {
			lists[i].RuleCount = n
		}
		if updated {
			lists[i].LastUpdateTime = updatedLabel
		}
	}

	return map[string]interface{}{
		"lists":           lists,
		"domainInfoCount": countByType["domain-info"],
		"threatCount":     countByType["threat-database"],
		"trackerCount":    countByType["tracker"],
		"lastUpdateTime":  updatedLabel,
	}
}
// getRuleCount returns the number of loaded entries for a list type
// ("domain-info", "threat-database" or "tracker"), or 0 for any other type.
// NOTE(review): reads the tables without taking m.mutex; callers presumably
// hold it — confirm at call sites.
func (m *DomainInfoManager) getRuleCount(entryType string) int {
	n := 0
	switch entryType {
	case "domain-info":
		n = len(m.domainInfoMap)
	case "threat-database":
		n = len(m.threatDatabase)
	case "tracker":
		n = len(m.trackerMap)
	}
	return n
}
// UpdateDomainInfoList re-fetches the first configured list whose Type is
// entryType (its Enabled flag is not consulted) and rebuilds the matching
// in-memory table.
//
// If the fetch or parse fails, the previously loaded table is put back and
// the error is returned, so a failed manual refresh no longer leaves the
// table empty. On success the update time is recorded and the aggregated
// cache and stats files are rewritten, so the refresh survives a restart.
// It returns an error if no configured list with a known type matches.
func (m *DomainInfoManager) UpdateDomainInfoList(entryType string) error {
	for _, entry := range m.config.DomainInfoLists {
		if entry.Type != entryType {
			continue
		}

		// Start the loader from an empty table so it cannot merge into stale
		// entries. Keep the old table aside in case the refresh fails.
		var fetch func(string) error
		var restore func()
		m.mutex.Lock()
		switch entryType {
		case "domain-info":
			previous := m.domainInfoMap
			m.domainInfoMap = make(map[string]map[string]interface{})
			fetch = m.fetchRemoteDomainInfo
			restore = func() { m.domainInfoMap = previous }
		case "threat-database":
			previous := m.threatDatabase
			m.threatDatabase = make(map[string]map[string]string)
			fetch = m.fetchThreatDatabase
			restore = func() { m.threatDatabase = previous }
		case "tracker":
			previous := m.trackerMap
			m.trackerMap = make(map[string]map[string]interface{})
			fetch = m.fetchTrackerInfo
			restore = func() { m.trackerMap = previous }
		}
		m.mutex.Unlock()

		if fetch == nil {
			// Unknown type: nothing to refresh for this entry.
			continue
		}
		if err := fetch(entry.URL); err != nil {
			m.mutex.Lock()
			restore()
			m.mutex.Unlock()
			return err
		}

		m.mutex.Lock()
		m.lastUpdateTime = time.Now()
		m.mutex.Unlock()
		// Called without m.mutex held; saveToCache/saveStats lock internally.
		m.saveToCache()
		return nil
	}
	return fmt.Errorf("未找到类型为 %s 的域名信息配置", entryType)
}
// AddDomainInfoList appends entry to the configured lists and always returns
// nil. This method neither validates nor downloads the list, and it does not
// persist the configuration.
func (m *DomainInfoManager) AddDomainInfoList(entry config.DomainInfoEntry) error {
	cfg := m.config
	cfg.DomainInfoLists = append(cfg.DomainInfoLists, entry)
	return nil
}
// RemoveDomainInfoList drops every configured list whose Type is entryType,
// empties the matching in-memory table, and rewrites the aggregated cache and
// stats files. It always returns nil.
func (m *DomainInfoManager) RemoveDomainInfoList(entryType string) error {
	var kept []config.DomainInfoEntry
	for _, entry := range m.config.DomainInfoLists {
		if entry.Type == entryType {
			continue
		}
		kept = append(kept, entry)
	}
	m.config.DomainInfoLists = kept

	m.mutex.Lock()
	switch entryType {
	case "domain-info":
		m.domainInfoMap = make(map[string]map[string]interface{})
	case "threat-database":
		m.threatDatabase = make(map[string]map[string]string)
	case "tracker":
		m.trackerMap = make(map[string]map[string]interface{})
	}
	m.mutex.Unlock()

	m.saveToCache()
	return nil
}
// GetUpdateStatus returns a snapshot of the update progress map, keyed by list
// type plus "overall". The copy is safe to read while updates continue.
func (m *DomainInfoManager) GetUpdateStatus() map[string]string {
	m.mutex.RLock()
	snapshot := make(map[string]string, len(m.updateStatus))
	for key, state := range m.updateStatus {
		snapshot[key] = state
	}
	m.mutex.RUnlock()
	return snapshot
}
// DomainInfoListStat is the persisted summary of one configured list: its
// identity from the configuration, the number of loaded entries for its type,
// and the last update time (RFC 3339, empty if never updated).
//
// It is an alias of an unnamed struct, so it is identical to an anonymous
// struct with the same fields and tags; such literals remain interchangeable
// with it.
type DomainInfoListStat = struct {
	Name           string `json:"name"`
	Type           string `json:"type"`
	URL            string `json:"url"`
	Enabled        bool   `json:"enabled"`
	RuleCount      int    `json:"ruleCount"`
	LastUpdateTime string `json:"lastUpdateTime"`
}

// DomainInfoStats is the JSON document written to the stats file by saveStats
// and read back by loadStats. It holds the overall table sizes, the last
// update time (RFC 3339, empty if never updated) and one summary per
// configured list.
type DomainInfoStats struct {
	DomainInfoCount int                  `json:"domainInfoCount"`
	ThreatCount     int                  `json:"threatCount"`
	TrackerCount    int                  `json:"trackerCount"`
	LastUpdateTime  string               `json:"lastUpdateTime"`
	Lists           []DomainInfoListStat `json:"lists"`
}
// saveStats writes the table sizes and last update time, overall and for each
// configured list, to statsFile as indented JSON. Errors are logged, not
// returned. It holds m.mutex.RLock for the whole call, so callers must not
// hold m.mutex themselves.
func (m *DomainInfoManager) saveStats() {
	type listStat = struct {
		Name           string `json:"name"`
		Type           string `json:"type"`
		URL            string `json:"url"`
		Enabled        bool   `json:"enabled"`
		RuleCount      int    `json:"ruleCount"`
		LastUpdateTime string `json:"lastUpdateTime"`
	}

	m.mutex.RLock()
	defer m.mutex.RUnlock()

	updated := ""
	if !m.lastUpdateTime.IsZero() {
		updated = m.lastUpdateTime.Format(time.RFC3339)
	}
	countFor := func(entryType string) int {
		switch entryType {
		case "domain-info":
			return len(m.domainInfoMap)
		case "threat-database":
			return len(m.threatDatabase)
		case "tracker":
			return len(m.trackerMap)
		}
		return 0
	}

	stats := DomainInfoStats{
		DomainInfoCount: len(m.domainInfoMap),
		ThreatCount:     len(m.threatDatabase),
		TrackerCount:    len(m.trackerMap),
		LastUpdateTime:  updated,
	}
	for _, entry := range m.config.DomainInfoLists {
		stats.Lists = append(stats.Lists, listStat{
			Name:           entry.Name,
			Type:           entry.Type,
			URL:            entry.URL,
			Enabled:        entry.Enabled,
			RuleCount:      countFor(entry.Type),
			LastUpdateTime: updated,
		})
	}

	encoded, err := json.MarshalIndent(stats, "", " ")
	if err != nil {
		logger.Error("序列化域名信息统计失败", "error", err)
		return
	}
	if err := ioutil.WriteFile(m.statsFile, encoded, 0644); err != nil {
		logger.Error("保存域名信息统计失败", "error", err)
	}
}
// loadStats reads statsFile and copies RuleCount and LastUpdateTime onto each
// configured list. Values come from the first saved list with the same Type.
// It returns the read or decode error, if any. The in-memory domain tables
// are not touched.
func (m *DomainInfoManager) loadStats() error {
	raw, err := ioutil.ReadFile(m.statsFile)
	if err != nil {
		return err
	}
	var stats DomainInfoStats
	if err = json.Unmarshal(raw, &stats); err != nil {
		return err
	}

	lists := m.config.DomainInfoLists // shares the backing array; writes land in config
	for i := range lists {
		for _, saved := range stats.Lists {
			if saved.Type != lists[i].Type {
				continue
			}
			lists[i].RuleCount = saved.RuleCount
			lists[i].LastUpdateTime = saved.LastUpdateTime
			break
		}
	}
	return nil
}