Files
2026-04-03 10:04:07 +08:00

312 lines
7.2 KiB
Go

package threat
import (
"encoding/csv"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"dns-server/logger"
)
// ThreatInfo 威胁域名详细信息
type ThreatInfo struct {
Type string // 威胁类型(如:钓鱼网站、木马、仿冒软件等)
Name string // 威胁名称(如:Silver fox 团伙)
RiskLevel int // 风险等级(1-3,1=低,2=中,3=高)
Domain string // 域名
}
// ThreatDatabaseManager 威胁域名数据库管理器
type ThreatDatabaseManager struct {
databasePath string
threatData map[string]*ThreatInfo // 域名 -> 威胁信息
mutex sync.RWMutex
watcher *fsnotify.Watcher
watcherDone chan struct{}
watcherMutex sync.Mutex
isWatching bool
}
// NewThreatDatabaseManager 创建威胁域名数据库管理器
func NewThreatDatabaseManager(databasePath string) *ThreatDatabaseManager {
return &ThreatDatabaseManager{
databasePath: databasePath,
threatData: make(map[string]*ThreatInfo),
watcherDone: make(chan struct{}),
}
}
// LoadDatabase 加载威胁域名数据库
func (m *ThreatDatabaseManager) LoadDatabase() error {
logger.Info("开始加载威胁域名数据库", "path", m.databasePath)
// 打开CSV文件
file, err := os.Open(m.databasePath)
if err != nil {
// 如果文件不存在,创建一个空文件
if os.IsNotExist(err) {
file, err = os.Create(m.databasePath)
if err != nil {
logger.Error("创建威胁域名数据库文件失败", "error", err)
return err
}
file.Close()
logger.Info("创建了空的威胁域名数据库文件")
return nil
}
logger.Error("打开威胁域名数据库文件失败", "error", err)
return err
}
defer file.Close()
// 读取CSV文件
reader := csv.NewReader(file)
records, err := reader.ReadAll()
if err != nil {
logger.Error("读取威胁域名数据库文件失败", "error", err)
return err
}
// 加载威胁域名
m.mutex.Lock()
defer m.mutex.Unlock()
count := 0
skipHeader := true
for _, record := range records {
// 跳过表头行
if skipHeader {
skipHeader = false
continue
}
if len(record) >= 4 { // 确保至少有4列
threatType := record[0] // 威胁类型
name := record[1] // 威胁名称
riskLevelStr := record[2] // 风险等级
domain := record[3] // 域名
// 解析风险等级
riskLevel, err := strconv.Atoi(riskLevelStr)
if err != nil {
riskLevel = 1 // 默认低风险
}
// 存储威胁信息
m.threatData[strings.ToLower(domain)] = &ThreatInfo{
Type: threatType,
Name: name,
RiskLevel: riskLevel,
Domain: domain,
}
count++
}
}
logger.Info("威胁域名数据库加载完成", "count", count)
return nil
}
// GetThreatInfo 获取域名的威胁信息
func (m *ThreatDatabaseManager) GetThreatInfo(domain string) *ThreatInfo {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.threatData[strings.ToLower(domain)]
}
// IsThreatDomain 检查域名是否在威胁数据库中
func (m *ThreatDatabaseManager) IsThreatDomain(domain string) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
_, exists := m.threatData[strings.ToLower(domain)]
return exists
}
// AddThreatDomain 添加威胁域名到数据库
func (m *ThreatDatabaseManager) AddThreatDomain(threatType, name string, riskLevel int, domain string) error {
m.mutex.Lock()
m.threatData[strings.ToLower(domain)] = &ThreatInfo{
Type: threatType,
Name: name,
RiskLevel: riskLevel,
Domain: domain,
}
m.mutex.Unlock()
// 保存到文件
return m.saveDatabase()
}
// RemoveThreatDomain 从数据库中移除威胁域名
func (m *ThreatDatabaseManager) RemoveThreatDomain(domain string) error {
m.mutex.Lock()
delete(m.threatData, strings.ToLower(domain))
m.mutex.Unlock()
// 保存到文件
return m.saveDatabase()
}
// GetAllThreatDomains 获取所有威胁域名信息
func (m *ThreatDatabaseManager) GetAllThreatDomains() []*ThreatInfo {
m.mutex.RLock()
defer m.mutex.RUnlock()
threats := make([]*ThreatInfo, 0, len(m.threatData))
for _, info := range m.threatData {
threats = append(threats, info)
}
return threats
}
// saveDatabase 保存数据库到文件
func (m *ThreatDatabaseManager) saveDatabase() error {
// 打开CSV文件
file, err := os.Create(m.databasePath)
if err != nil {
return err
}
defer file.Close()
// 写入CSV文件
writer := csv.NewWriter(file)
defer writer.Flush()
// 写入表头
if err := writer.Write([]string{"type", "name", "riskLevel", "domain"}); err != nil {
return err
}
m.mutex.RLock()
defer m.mutex.RUnlock()
for _, info := range m.threatData {
record := []string{
info.Type,
info.Name,
strconv.Itoa(info.RiskLevel),
info.Domain,
}
if err := writer.Write(record); err != nil {
return err
}
}
return nil
}
// StartWatching 开始监听数据库文件变化
func (m *ThreatDatabaseManager) StartWatching() error {
m.watcherMutex.Lock()
defer m.watcherMutex.Unlock()
if m.isWatching {
logger.Info("文件监听已经在运行中")
return nil
}
var err error
m.watcher, err = fsnotify.NewWatcher()
if err != nil {
logger.Error("创建文件监听器失败", "error", err)
return err
}
// 确保目录存在
dir := filepath.Dir(m.databasePath)
if _, err := os.Stat(dir); os.IsNotExist(err) {
logger.Error("数据库文件目录不存在", "dir", dir, "error", err)
return err
}
// 添加目录监听
if err := m.watcher.Add(dir); err != nil {
logger.Error("添加目录监听失败", "dir", dir, "error", err)
m.watcher.Close()
return err
}
m.isWatching = true
logger.Info("开始监听威胁域名数据库文件", "path", m.databasePath)
// 启动监听协程
go m.watchFileChanges()
return nil
}
// StopWatching 停止监听数据库文件变化
func (m *ThreatDatabaseManager) StopWatching() {
m.watcherMutex.Lock()
defer m.watcherMutex.Unlock()
if !m.isWatching {
return
}
close(m.watcherDone)
if m.watcher != nil {
m.watcher.Close()
}
m.isWatching = false
logger.Info("停止监听威胁域名数据库文件")
}
// watchFileChanges 监听文件变化并重新加载数据库
func (m *ThreatDatabaseManager) watchFileChanges() {
var reloadTimer *time.Timer
defer func() {
if reloadTimer != nil {
reloadTimer.Stop()
}
}()
for {
select {
case event, ok := <-m.watcher.Events:
if !ok {
return
}
// 检查是否是我们关注的文件
if filepath.Clean(event.Name) != filepath.Clean(m.databasePath) {
continue
}
logger.Debug("监听到文件变化", "event", event.Op.String(), "file", event.Name)
// 只关注文件写入和创建事件
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
// 使用防抖机制,避免频繁重新加载
if reloadTimer != nil {
reloadTimer.Stop()
}
reloadTimer = time.AfterFunc(500*time.Millisecond, func() {
logger.Info("检测到数据库文件变化,开始重新加载", "path", m.databasePath)
if err := m.LoadDatabase(); err != nil {
logger.Error("重新加载数据库失败", "error", err)
} else {
logger.Info("数据库重新加载成功")
}
})
}
case err, ok := <-m.watcher.Errors:
if !ok {
return
}
logger.Error("文件监听器错误", "error", err)
case <-m.watcherDone:
return
}
}
}