package shield

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/csv"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"dns-server/config"
	"dns-server/logger"
)

// domainInfoHTTPClient is shared by all remote list downloads. The timeout
// bounds how long a stalled server can hold up an update cycle; the value is
// a conservative default, not derived from configuration.
var domainInfoHTTPClient = &http.Client{Timeout: 60 * time.Second}

// allDomainInfoCacheFile is the file (inside cacheDir) that holds the combined
// snapshot of all three datasets.
const allDomainInfoCacheFile = "all_domain_info.json"

// DomainInfoManager downloads, caches and serves three datasets keyed by
// domain name: general domain info, a threat database and tracker info.
//
// Locking: mutex guards the three maps, updateRunning, updateStatus,
// lastUpdateTime and config.DomainInfoLists. sync.RWMutex is not reentrant,
// so the persistence helpers (saveToCache, saveStats), which take the lock
// themselves, must never be called while mutex is held.
type DomainInfoManager struct {
	config         *config.DomainInfoConfig
	domainInfoMap  map[string]map[string]interface{}
	threatDatabase map[string]map[string]string
	trackerMap     map[string]map[string]interface{}
	mutex          sync.RWMutex
	updateCtx      context.Context
	updateCancel   context.CancelFunc
	updateRunning  bool
	updateStatus   map[string]string
	lastUpdateTime time.Time
	updateInterval time.Duration
	cacheDir       string
	statsFile      string
}

// DomainInfo is the JSON shape of a domain-info record.
type DomainInfo struct {
	Domain    string                 `json:"domain"`
	Category  string                 `json:"category"`
	Type      string                 `json:"type"`
	RiskLevel int                    `json:"riskLevel"`
	Info      map[string]interface{} `json:"info,omitempty"`
}

// ThreatInfo is the JSON shape of a threat-database record.
type ThreatInfo struct {
	Type      string `json:"type"`
	Name      string `json:"name"`
	RiskLevel int    `json:"riskLevel"`
	Domain    string `json:"domain"`
}

// TrackerInfo is the JSON shape of a tracker record.
type TrackerInfo struct {
	Domain   string                 `json:"domain"`
	Category string                 `json:"category"`
	Tracker  string                 `json:"tracker"`
	Info     map[string]interface{} `json:"info,omitempty"`
}

// domainInfoCache is the on-disk layout of allDomainInfoCacheFile.
type domainInfoCache struct {
	DomainInfoMap  map[string]map[string]interface{} `json:"domainInfoMap"`
	ThreatDatabase map[string]map[string]string      `json:"threatDatabase"`
	TrackerMap     map[string]map[string]interface{} `json:"trackerMap"`
	LastUpdateTime time.Time                         `json:"lastUpdateTime"`
}

// NewDomainInfoManager creates a manager for cfg and ensures the cache
// directory exists (logging on failure). No data is loaded until Start is
// called. A nil cfg is treated as an empty configuration.
func NewDomainInfoManager(cfg *config.DomainInfoConfig) *DomainInfoManager {
	if cfg == nil {
		cfg = &config.DomainInfoConfig{}
	}
	ctx, cancel := context.WithCancel(context.Background())
	m := &DomainInfoManager{
		config:         cfg,
		domainInfoMap:  make(map[string]map[string]interface{}),
		threatDatabase: make(map[string]map[string]string),
		trackerMap:     make(map[string]map[string]interface{}),
		updateCtx:      ctx,
		updateCancel:   cancel,
		updateRunning:  false,
		updateStatus:   make(map[string]string),
		updateInterval: time.Duration(cfg.UpdateInterval) * time.Second,
		cacheDir:       "data/domain_info_cache",
		statsFile:      "data/domain_info_stats.json",
	}
	if err := os.MkdirAll(m.cacheDir, 0755); err != nil {
		logger.Error("创建域名信息缓存目录失败", "error", err)
	}
	return m
}

// Start loads persisted statistics and cached data in the background, and then,
// if enabled in the config, keeps refreshing the lists periodically until Stop
// is called. It returns immediately.
func (m *DomainInfoManager) Start() {
	logger.Info("启动域名信息管理器")
	go func() {
		if err := m.loadStats(); err != nil {
			if os.IsNotExist(err) {
				logger.Info("未找到域名信息统计文件,将使用默认值")
			} else {
				logger.Warn("加载域名信息统计失败,将使用默认值", "error", err)
			}
		}
		m.loadFromCache()
		if m.config.EnableAutoUpdate {
			m.autoUpdateLoop()
		}
	}()
}

// Stop ends the auto-update loop and cancels any in-flight downloads.
func (m *DomainInfoManager) Stop() {
	logger.Info("停止域名信息管理器")
	m.updateCancel()
}

// autoUpdateLoop calls LoadDomainInfo every updateInterval until the manager
// is stopped. A non-positive interval disables the loop (time.NewTicker would
// panic on it).
func (m *DomainInfoManager) autoUpdateLoop() {
	if m.updateInterval <= 0 {
		logger.Warn("域名信息自动更新间隔无效,已禁用自动更新", "interval", m.updateInterval)
		return
	}
	ticker := time.NewTicker(m.updateInterval)
	defer ticker.Stop()
	for {
		select {
		case <-m.updateCtx.Done():
			return
		case <-ticker.C:
			logger.Info("开始自动更新域名信息")
			if err := m.LoadDomainInfo(); err != nil {
				logger.Error("自动更新域名信息失败", "error", err)
			}
		}
	}
}

// LoadDomainInfo refreshes every enabled list concurrently, then persists the
// combined snapshot and statistics. If a refresh is already running it returns
// nil immediately. A failing list does not stop the others; all failures are
// joined into the returned error, and whatever did load is still cached.
func (m *DomainInfoManager) LoadDomainInfo() error {
	m.mutex.Lock()
	if m.updateRunning {
		m.mutex.Unlock()
		logger.Info("域名信息正在更新中,跳过")
		return nil
	}
	m.updateRunning = true
	m.updateStatus["overall"] = "running"
	// Snapshot the enabled entries while locked; the config slice can be
	// changed concurrently by AddDomainInfoList / RemoveDomainInfoList.
	var entries []config.DomainInfoEntry
	for _, entry := range m.config.DomainInfoLists {
		if entry.Enabled {
			entries = append(entries, entry)
		}
	}
	m.mutex.Unlock()

	overall := "completed"
	defer func() {
		m.mutex.Lock()
		m.updateRunning = false
		m.updateStatus["overall"] = overall
		m.mutex.Unlock()
	}()

	var wg sync.WaitGroup
	// One slot per entry so a worker can never block on send.
	errCh := make(chan error, len(entries))
	for _, entry := range entries {
		wg.Add(1)
		go func(entry config.DomainInfoEntry) {
			defer wg.Done()
			m.setUpdateStatus(entry.Type, "running")
			if err := m.fetchEntry(entry); err != nil {
				m.setUpdateStatus(entry.Type, "failed")
				errCh <- fmt.Errorf("更新 %s 失败: %w", entry.Name, err)
				return
			}
			m.setUpdateStatus(entry.Type, "completed")
		}(entry)
	}
	wg.Wait()
	close(errCh)

	var failures []string
	for err := range errCh {
		failures = append(failures, err.Error())
	}

	m.mutex.Lock()
	m.lastUpdateTime = time.Now()
	domainInfoCount := len(m.domainInfoMap)
	threatCount := len(m.threatDatabase)
	trackerCount := len(m.trackerMap)
	m.mutex.Unlock()

	// saveToCache acquires the lock itself; it must run with m.mutex released.
	m.saveToCache()

	if len(failures) > 0 {
		overall = "failed"
		return fmt.Errorf("更新过程中有错误: %s", strings.Join(failures, "; "))
	}
	logger.Info("域名信息更新完成",
		"domainInfoCount", domainInfoCount,
		"threatCount", threatCount,
		"trackerCount", trackerCount)
	return nil
}

// setUpdateStatus records status for key in updateStatus under the lock.
func (m *DomainInfoManager) setUpdateStatus(key, status string) {
	m.mutex.Lock()
	m.updateStatus[key] = status
	m.mutex.Unlock()
}

// fetchEntry dispatches one configured list to the loader for its type.
// Unknown types are logged and treated as success.
func (m *DomainInfoManager) fetchEntry(entry config.DomainInfoEntry) error {
	switch entry.Type {
	case "domain-info":
		return m.fetchRemoteDomainInfo(entry.URL)
	case "threat-database":
		return m.fetchThreatDatabase(entry.URL)
	case "tracker":
		return m.fetchTrackerInfo(entry.URL)
	default:
		logger.Warn("未知的域名信息类型", "type", entry.Type)
		return nil
	}
}

// fetchRemoteDomainInfo downloads the domain-info JSON at rawURL and replaces
// the domain-info dataset with it. Existing data is kept if download or
// parsing fails.
func (m *DomainInfoManager) fetchRemoteDomainInfo(rawURL string) error {
	data, err := m.fetchRemoteData(rawURL)
	if err != nil {
		return err
	}
	parsed, err := m.parseDomainInfo(data)
	if err != nil {
		return err
	}
	m.mutex.Lock()
	m.domainInfoMap = parsed
	m.mutex.Unlock()
	logger.Info("加载域名信息", "count", len(parsed), "url", rawURL)
	return nil
}

// parseDomainInfo reads the nested {"domains": {company: {service: {"url": ...}}}}
// layout and indexes each service record by the host name of its "url". The
// company and service names are added to the record under "company" and
// "service". Malformed nodes are skipped.
func (m *DomainInfoManager) parseDomainInfo(data []byte) (map[string]map[string]interface{}, error) {
	var payload map[string]interface{}
	if err := json.Unmarshal(data, &payload); err != nil {
		return nil, fmt.Errorf("解析域名信息 JSON 失败: %w", err)
	}
	result := make(map[string]map[string]interface{})
	domains, _ := payload["domains"].(map[string]interface{})
	for company, services := range domains {
		servicesMap, ok := services.(map[string]interface{})
		if !ok {
			continue
		}
		for serviceName, serviceInfo := range servicesMap {
			infoMap, ok := serviceInfo.(map[string]interface{})
			if !ok {
				continue
			}
			urlStr, ok := infoMap["url"].(string)
			if !ok {
				continue
			}
			domain := m.extractDomainFromURL(urlStr)
			if domain == "" {
				continue
			}
			infoMap["company"] = company
			infoMap["service"] = serviceName
			result[domain] = infoMap
		}
	}
	return result, nil
}

// fetchThreatDatabase downloads the threat CSV at rawURL and merges its rows
// into the threat database; entries for domains not in the file are kept.
func (m *DomainInfoManager) fetchThreatDatabase(rawURL string) error {
	return m.storeThreatDatabase(rawURL, false)
}

// storeThreatDatabase downloads and parses the threat CSV at rawURL, then
// merges it into (replace=false) or substitutes it for (replace=true) the
// current threat database. Nothing changes if download or parsing fails.
func (m *DomainInfoManager) storeThreatDatabase(rawURL string, replace bool) error {
	data, err := m.fetchRemoteData(rawURL)
	if err != nil {
		return err
	}
	threats, err := parseThreatDatabase(data)
	if err != nil {
		return err
	}
	m.mutex.Lock()
	if replace {
		m.threatDatabase = threats
	} else {
		for domain, info := range threats {
			m.threatDatabase[domain] = info
		}
	}
	count := len(m.threatDatabase)
	m.mutex.Unlock()
	logger.Info("加载威胁数据库", "count", count, "url", rawURL)
	return nil
}

// parseThreatDatabase parses CSV rows of the form type,name,riskLevel,domain.
// The first row is treated as a header. Rows with fewer than 4 fields or an
// empty domain are skipped.
func parseThreatDatabase(data []byte) (map[string]map[string]string, error) {
	reader := csv.NewReader(bytes.NewReader(data))
	// Allow rows of varying length; short rows are filtered below instead of
	// aborting the whole parse.
	reader.FieldsPerRecord = -1
	records, err := reader.ReadAll()
	if err != nil {
		return nil, fmt.Errorf("解析威胁数据库 CSV 失败: %w", err)
	}
	threats := make(map[string]map[string]string, len(records))
	for i, record := range records {
		if i == 0 || len(record) < 4 {
			continue
		}
		domain := strings.TrimSpace(record[3])
		if domain == "" {
			continue
		}
		threats[domain] = map[string]string{
			"type":      strings.TrimSpace(record[0]),
			"name":      strings.TrimSpace(record[1]),
			"riskLevel": strings.TrimSpace(record[2]),
			"domain":    domain,
		}
	}
	return threats, nil
}

// fetchTrackerInfo downloads the tracker JSON at rawURL and replaces the
// tracker dataset with it. Existing data is kept if download or parsing fails.
func (m *DomainInfoManager) fetchTrackerInfo(rawURL string) error {
	data, err := m.fetchRemoteData(rawURL)
	if err != nil {
		return err
	}
	trackers, err := parseTrackerInfo(data)
	if err != nil {
		return err
	}
	m.mutex.Lock()
	m.trackerMap = trackers
	m.mutex.Unlock()
	logger.Info("加载跟踪器信息", "count", len(trackers), "url", rawURL)
	return nil
}

// parseTrackerInfo builds the tracker index from two optional fields:
// "trackers" (domain -> info object) and "trackerDomains" (tracker name ->
// list of domains). Explicit "trackers" entries take precedence; domains that
// appear only in "trackerDomains" get a minimal {"tracker", "domain"} record.
func parseTrackerInfo(data []byte) (map[string]map[string]interface{}, error) {
	var payload map[string]interface{}
	if err := json.Unmarshal(data, &payload); err != nil {
		return nil, fmt.Errorf("解析跟踪器信息 JSON 失败: %w", err)
	}
	trackers := make(map[string]map[string]interface{})
	if explicit, ok := payload["trackers"].(map[string]interface{}); ok {
		for domain, info := range explicit {
			if infoMap, ok := info.(map[string]interface{}); ok {
				trackers[domain] = infoMap
			}
		}
	}
	if index, ok := payload["trackerDomains"].(map[string]interface{}); ok {
		for trackerName, domains := range index {
			list, ok := domains.([]interface{})
			if !ok {
				continue
			}
			for _, d := range list {
				domain, ok := d.(string)
				if !ok {
					continue
				}
				if _, exists := trackers[domain]; !exists {
					trackers[domain] = map[string]interface{}{
						"tracker": trackerName,
						"domain":  domain,
					}
				}
			}
		}
	}
	return trackers, nil
}

// fetchRemoteData returns the body for rawURL. If the per-URL cache file is
// newer than updateInterval, the cache is used without touching the network.
// Otherwise the URL is downloaded and the cache refreshed. When the download
// fails for any reason, the cached copy is used if it exists; if it does not,
// the download error is returned.
func (m *DomainInfoManager) fetchRemoteData(rawURL string) ([]byte, error) {
	cacheFile := m.getCacheFilePath(rawURL)
	if !m.shouldUpdateCache(cacheFile) {
		return m.loadFromCacheFile(cacheFile)
	}

	data, err := m.download(rawURL)
	if err != nil {
		logger.Warn("从远程获取数据失败,尝试使用缓存", "url", rawURL, "error", err)
		cached, cacheErr := m.loadFromCacheFile(cacheFile)
		if cacheErr != nil {
			return nil, err
		}
		return cached, nil
	}

	if err := os.WriteFile(cacheFile, data, 0644); err != nil {
		logger.Warn("保存缓存文件失败", "error", err)
	}
	return data, nil
}

// download performs a GET on rawURL, bound to the manager's lifetime context
// so Stop aborts it, and returns the body of a 200 response.
func (m *DomainInfoManager) download(rawURL string) ([]byte, error) {
	req, err := http.NewRequestWithContext(m.updateCtx, http.MethodGet, rawURL, nil)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	resp, err := domainInfoHTTPClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP 状态码: %d", resp.StatusCode)
	}
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	return data, nil
}

// getCacheFilePath maps a URL to a stable file name in cacheDir (hex SHA-256).
func (m *DomainInfoManager) getCacheFilePath(rawURL string) string {
	hash := fmt.Sprintf("%x", sha256.Sum256([]byte(rawURL)))
	return filepath.Join(m.cacheDir, hash+".cache")
}

// shouldUpdateCache reports whether cacheFile is missing, unreadable, or older
// than updateInterval.
func (m *DomainInfoManager) shouldUpdateCache(cacheFile string) bool {
	info, err := os.Stat(cacheFile)
	if err != nil {
		return true
	}
	return time.Since(info.ModTime()) > m.updateInterval
}

// loadFromCacheFile returns the raw contents of cacheFile.
func (m *DomainInfoManager) loadFromCacheFile(cacheFile string) ([]byte, error) {
	return os.ReadFile(cacheFile)
}

// saveToCache writes a snapshot of all three datasets and lastUpdateTime to
// allDomainInfoCacheFile, then refreshes the stats file. It takes the read
// lock itself, so it must not be called with m.mutex held. Errors are logged,
// not returned.
func (m *DomainInfoManager) saveToCache() {
	m.mutex.RLock()
	jsonData, err := json.Marshal(domainInfoCache{
		DomainInfoMap:  m.domainInfoMap,
		ThreatDatabase: m.threatDatabase,
		TrackerMap:     m.trackerMap,
		LastUpdateTime: m.lastUpdateTime,
	})
	m.mutex.RUnlock()
	if err != nil {
		logger.Error("序列化域名信息缓存失败", "error", err)
		return
	}

	cacheFile := filepath.Join(m.cacheDir, allDomainInfoCacheFile)
	if err := os.WriteFile(cacheFile, jsonData, 0644); err != nil {
		logger.Error("保存域名信息缓存失败", "error", err)
	}
	m.saveStats()
}

// loadFromCache restores the datasets (and lastUpdateTime, if none is set yet)
// from allDomainInfoCacheFile. If the file is missing or unparseable, it
// performs a full remote load instead.
func (m *DomainInfoManager) loadFromCache() {
	cacheFile := filepath.Join(m.cacheDir, allDomainInfoCacheFile)
	data, err := os.ReadFile(cacheFile)
	if err != nil {
		logger.Info("未找到域名信息缓存,将从远程加载")
		m.initialRemoteLoad()
		return
	}

	var cached domainInfoCache
	if err := json.Unmarshal(data, &cached); err != nil {
		logger.Error("解析域名信息缓存失败", "error", err)
		// A corrupt snapshot is unusable; fetch fresh data rather than start empty.
		m.initialRemoteLoad()
		return
	}

	m.mutex.Lock()
	for domain, info := range cached.DomainInfoMap {
		if info != nil {
			m.domainInfoMap[domain] = info
		}
	}
	for domain, info := range cached.ThreatDatabase {
		if info != nil {
			m.threatDatabase[domain] = info
		}
	}
	for domain, info := range cached.TrackerMap {
		if info != nil {
			m.trackerMap[domain] = info
		}
	}
	// Restore the timestamp so the stats rewrite below keeps the real time.
	if m.lastUpdateTime.IsZero() {
		m.lastUpdateTime = cached.LastUpdateTime
	}
	domainInfoCount := len(m.domainInfoMap)
	threatCount := len(m.threatDatabase)
	trackerCount := len(m.trackerMap)
	m.mutex.Unlock()

	logger.Info("从缓存加载域名信息完成",
		"domainInfoCount", domainInfoCount,
		"threatCount", threatCount,
		"trackerCount", trackerCount)
	m.saveStats()
}

// initialRemoteLoad runs LoadDomainInfo and logs (rather than returns) any error.
func (m *DomainInfoManager) initialRemoteLoad() {
	if err := m.LoadDomainInfo(); err != nil {
		logger.Error("初始加载域名信息失败", "error", err)
	}
}

// GetDomainInfo returns the domain-info record for the exact key domain, or
// nil. The map is shared with the manager and must be treated as read-only.
func (m *DomainInfoManager) GetDomainInfo(domain string) map[string]interface{} {
	m.mutex.RLock()
	defer m.mutex.RUnlock()
	return m.domainInfoMap[domain]
}

// GetThreatInfo returns the threat record for the exact key domain, or nil.
// The map is shared with the manager and must be treated as read-only.
func (m *DomainInfoManager) GetThreatInfo(domain string) map[string]string {
	m.mutex.RLock()
	defer m.mutex.RUnlock()
	return m.threatDatabase[domain]
}

// GetTrackerInfo returns the tracker record for the exact key domain, or nil.
// The map is shared with the manager and must be treated as read-only.
func (m *DomainInfoManager) GetTrackerInfo(domain string) map[string]interface{} {
	m.mutex.RLock()
	defer m.mutex.RUnlock()
	return m.trackerMap[domain]
}

// extractDomainFromURL returns the host name (without port) of urlStr, or ""
// if it cannot be parsed. Scheme-less inputs such as "example.com" parse as a
// path and therefore also yield "".
func (m *DomainInfoManager) extractDomainFromURL(urlStr string) string {
	parsedURL, err := url.Parse(urlStr)
	if err != nil {
		return ""
	}
	return parsedURL.Hostname()
}

// GetAllDomainInfo returns a summary for the API: a copy of the configured
// lists annotated with rule counts and last update time, plus global counts.
func (m *DomainInfoManager) GetAllDomainInfo() map[string]interface{} {
	m.mutex.RLock()
	domainInfoCount := len(m.domainInfoMap)
	threatCount := len(m.threatDatabase)
	trackerCount := len(m.trackerMap)
	lastUpdateTime := m.lastUpdateTime
	lists := []config.DomainInfoEntry{}
	if m.config != nil {
		lists = make([]config.DomainInfoEntry, len(m.config.DomainInfoLists))
		copy(lists, m.config.DomainInfoLists)
	}
	m.mutex.RUnlock()

	lastUpdateTimeStr := "从未更新"
	if !lastUpdateTime.IsZero() {
		lastUpdateTimeStr = lastUpdateTime.Format(time.RFC3339)
	}

	for i := range lists {
		// Only the known types get a count; others keep the configured value.
		switch lists[i].Type {
		case "domain-info":
			lists[i].RuleCount = domainInfoCount
		case "threat-database":
			lists[i].RuleCount = threatCount
		case "tracker":
			lists[i].RuleCount = trackerCount
		}
		if !lastUpdateTime.IsZero() {
			lists[i].LastUpdateTime = lastUpdateTimeStr
		}
	}

	return map[string]interface{}{
		"lists":           lists,
		"domainInfoCount": domainInfoCount,
		"threatCount":     threatCount,
		"trackerCount":    trackerCount,
		"lastUpdateTime":  lastUpdateTimeStr,
	}
}

// getRuleCount returns the number of entries in the dataset for entryType,
// or 0 for an unknown type. The caller must hold m.mutex (read or write).
func (m *DomainInfoManager) getRuleCount(entryType string) int {
	switch entryType {
	case "domain-info":
		return len(m.domainInfoMap)
	case "threat-database":
		return len(m.threatDatabase)
	case "tracker":
		return len(m.trackerMap)
	}
	return 0
}

// UpdateDomainInfoList re-downloads the first configured list of entryType
// and replaces that dataset with it, then persists the cache. On failure the
// previous data is left intact. An error is returned when no list of that
// type is configured or the type is not one of the supported kinds.
func (m *DomainInfoManager) UpdateDomainInfoList(entryType string) error {
	var (
		target config.DomainInfoEntry
		found  bool
	)
	m.mutex.RLock()
	for _, entry := range m.config.DomainInfoLists {
		if entry.Type == entryType {
			target, found = entry, true
			break
		}
	}
	m.mutex.RUnlock()

	if !found {
		return fmt.Errorf("未找到类型为 %s 的域名信息配置", entryType)
	}

	var err error
	switch entryType {
	case "domain-info":
		err = m.fetchRemoteDomainInfo(target.URL)
	case "threat-database":
		err = m.storeThreatDatabase(target.URL, true)
	case "tracker":
		err = m.fetchTrackerInfo(target.URL)
	default:
		// Unsupported types cannot be refreshed and are reported as not found.
		return fmt.Errorf("未找到类型为 %s 的域名信息配置", entryType)
	}
	if err != nil {
		return err
	}
	m.saveToCache()
	return nil
}

// AddDomainInfoList appends entry to the configured lists. It does not fetch
// the list; that happens on the next update.
func (m *DomainInfoManager) AddDomainInfoList(entry config.DomainInfoEntry) error {
	m.mutex.Lock()
	m.config.DomainInfoLists = append(m.config.DomainInfoLists, entry)
	m.mutex.Unlock()
	return nil
}

// RemoveDomainInfoList removes every configured list of entryType, clears the
// corresponding dataset and persists the result.
func (m *DomainInfoManager) RemoveDomainInfoList(entryType string) error {
	m.mutex.Lock()
	kept := make([]config.DomainInfoEntry, 0, len(m.config.DomainInfoLists))
	for _, entry := range m.config.DomainInfoLists {
		if entry.Type != entryType {
			kept = append(kept, entry)
		}
	}
	m.config.DomainInfoLists = kept

	switch entryType {
	case "domain-info":
		m.domainInfoMap = make(map[string]map[string]interface{})
	case "threat-database":
		m.threatDatabase = make(map[string]map[string]string)
	case "tracker":
		m.trackerMap = make(map[string]map[string]interface{})
	}
	m.mutex.Unlock()

	// Must run after unlocking: saveToCache takes the lock itself.
	m.saveToCache()
	return nil
}

// GetUpdateStatus returns a copy of the per-type and "overall" update states
// ("running", "completed" or "failed").
func (m *DomainInfoManager) GetUpdateStatus() map[string]string {
	m.mutex.RLock()
	defer m.mutex.RUnlock()
	status := make(map[string]string, len(m.updateStatus))
	for k, v := range m.updateStatus {
		status[k] = v
	}
	return status
}

// domainInfoListStat is the per-list entry in DomainInfoStats. It is an alias
// (not a new type) so DomainInfoStats.Lists keeps its original element type.
type domainInfoListStat = struct {
	Name           string `json:"name"`
	Type           string `json:"type"`
	URL            string `json:"url"`
	Enabled        bool   `json:"enabled"`
	RuleCount      int    `json:"ruleCount"`
	LastUpdateTime string `json:"lastUpdateTime"`
}

// DomainInfoStats is the on-disk layout of the statistics file.
type DomainInfoStats struct {
	DomainInfoCount int                  `json:"domainInfoCount"`
	ThreatCount     int                  `json:"threatCount"`
	TrackerCount    int                  `json:"trackerCount"`
	LastUpdateTime  string               `json:"lastUpdateTime"`
	Lists           []domainInfoListStat `json:"lists"`
}

// saveStats writes the current counts and per-list metadata to statsFile. It
// takes the read lock itself, so it must not be called with m.mutex held.
// Errors are logged, not returned.
func (m *DomainInfoManager) saveStats() {
	m.mutex.RLock()
	lastUpdate := ""
	if !m.lastUpdateTime.IsZero() {
		lastUpdate = m.lastUpdateTime.Format(time.RFC3339)
	}
	stats := DomainInfoStats{
		DomainInfoCount: len(m.domainInfoMap),
		ThreatCount:     len(m.threatDatabase),
		TrackerCount:    len(m.trackerMap),
		LastUpdateTime:  lastUpdate,
	}
	for _, entry := range m.config.DomainInfoLists {
		listUpdate := lastUpdate
		if listUpdate == "" {
			// No update this run: keep the time previously loaded into the
			// config instead of overwriting it with an empty string.
			listUpdate = entry.LastUpdateTime
		}
		stats.Lists = append(stats.Lists, domainInfoListStat{
			Name:           entry.Name,
			Type:           entry.Type,
			URL:            entry.URL,
			Enabled:        entry.Enabled,
			RuleCount:      m.getRuleCount(entry.Type),
			LastUpdateTime: listUpdate,
		})
	}
	m.mutex.RUnlock()

	jsonData, err := json.MarshalIndent(stats, "", "  ")
	if err != nil {
		logger.Error("序列化域名信息统计失败", "error", err)
		return
	}
	if err := os.WriteFile(m.statsFile, jsonData, 0644); err != nil {
		logger.Error("保存域名信息统计失败", "error", err)
	}
}

// loadStats copies RuleCount and LastUpdateTime from statsFile into the
// configured lists, matching by Type (first match wins). The error from
// reading (e.g. file not found) or decoding the file is returned unchanged.
func (m *DomainInfoManager) loadStats() error {
	data, err := os.ReadFile(m.statsFile)
	if err != nil {
		return err
	}
	var stats DomainInfoStats
	if err := json.Unmarshal(data, &stats); err != nil {
		return err
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()
	for i := range m.config.DomainInfoLists {
		entry := &m.config.DomainInfoLists[i]
		for _, listStat := range stats.Lists {
			if entry.Type == listStat.Type {
				entry.RuleCount = listStat.RuleCount
				entry.LastUpdateTime = listStat.LastUpdateTime
				break
			}
		}
	}
	return nil
}