312 lines
7.2 KiB
Go
312 lines
7.2 KiB
Go
package threat
|
|
|
|
import (
|
|
"encoding/csv"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/fsnotify/fsnotify"
|
|
|
|
"dns-server/logger"
|
|
)
|
|
|
|
// ThreatInfo 威胁域名详细信息
|
|
type ThreatInfo struct {
|
|
Type string // 威胁类型(如:钓鱼网站、木马、仿冒软件等)
|
|
Name string // 威胁名称(如:Silver fox 团伙)
|
|
RiskLevel int // 风险等级(1-3,1=低,2=中,3=高)
|
|
Domain string // 域名
|
|
}
|
|
|
|
// ThreatDatabaseManager 威胁域名数据库管理器
|
|
type ThreatDatabaseManager struct {
|
|
databasePath string
|
|
threatData map[string]*ThreatInfo // 域名 -> 威胁信息
|
|
mutex sync.RWMutex
|
|
watcher *fsnotify.Watcher
|
|
watcherDone chan struct{}
|
|
watcherMutex sync.Mutex
|
|
isWatching bool
|
|
}
|
|
|
|
// NewThreatDatabaseManager 创建威胁域名数据库管理器
|
|
func NewThreatDatabaseManager(databasePath string) *ThreatDatabaseManager {
|
|
return &ThreatDatabaseManager{
|
|
databasePath: databasePath,
|
|
threatData: make(map[string]*ThreatInfo),
|
|
watcherDone: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// LoadDatabase 加载威胁域名数据库
|
|
func (m *ThreatDatabaseManager) LoadDatabase() error {
|
|
logger.Info("开始加载威胁域名数据库", "path", m.databasePath)
|
|
|
|
// 打开CSV文件
|
|
file, err := os.Open(m.databasePath)
|
|
if err != nil {
|
|
// 如果文件不存在,创建一个空文件
|
|
if os.IsNotExist(err) {
|
|
file, err = os.Create(m.databasePath)
|
|
if err != nil {
|
|
logger.Error("创建威胁域名数据库文件失败", "error", err)
|
|
return err
|
|
}
|
|
file.Close()
|
|
logger.Info("创建了空的威胁域名数据库文件")
|
|
return nil
|
|
}
|
|
logger.Error("打开威胁域名数据库文件失败", "error", err)
|
|
return err
|
|
}
|
|
defer file.Close()
|
|
|
|
// 读取CSV文件
|
|
reader := csv.NewReader(file)
|
|
records, err := reader.ReadAll()
|
|
if err != nil {
|
|
logger.Error("读取威胁域名数据库文件失败", "error", err)
|
|
return err
|
|
}
|
|
|
|
// 加载威胁域名
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
count := 0
|
|
skipHeader := true
|
|
for _, record := range records {
|
|
// 跳过表头行
|
|
if skipHeader {
|
|
skipHeader = false
|
|
continue
|
|
}
|
|
|
|
if len(record) >= 4 { // 确保至少有4列
|
|
threatType := record[0] // 威胁类型
|
|
name := record[1] // 威胁名称
|
|
riskLevelStr := record[2] // 风险等级
|
|
domain := record[3] // 域名
|
|
|
|
// 解析风险等级
|
|
riskLevel, err := strconv.Atoi(riskLevelStr)
|
|
if err != nil {
|
|
riskLevel = 1 // 默认低风险
|
|
}
|
|
|
|
// 存储威胁信息
|
|
m.threatData[strings.ToLower(domain)] = &ThreatInfo{
|
|
Type: threatType,
|
|
Name: name,
|
|
RiskLevel: riskLevel,
|
|
Domain: domain,
|
|
}
|
|
count++
|
|
}
|
|
}
|
|
|
|
logger.Info("威胁域名数据库加载完成", "count", count)
|
|
return nil
|
|
}
|
|
|
|
// GetThreatInfo 获取域名的威胁信息
|
|
func (m *ThreatDatabaseManager) GetThreatInfo(domain string) *ThreatInfo {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
return m.threatData[strings.ToLower(domain)]
|
|
}
|
|
|
|
// IsThreatDomain 检查域名是否在威胁数据库中
|
|
func (m *ThreatDatabaseManager) IsThreatDomain(domain string) bool {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
_, exists := m.threatData[strings.ToLower(domain)]
|
|
return exists
|
|
}
|
|
|
|
// AddThreatDomain 添加威胁域名到数据库
|
|
func (m *ThreatDatabaseManager) AddThreatDomain(threatType, name string, riskLevel int, domain string) error {
|
|
m.mutex.Lock()
|
|
m.threatData[strings.ToLower(domain)] = &ThreatInfo{
|
|
Type: threatType,
|
|
Name: name,
|
|
RiskLevel: riskLevel,
|
|
Domain: domain,
|
|
}
|
|
m.mutex.Unlock()
|
|
|
|
// 保存到文件
|
|
return m.saveDatabase()
|
|
}
|
|
|
|
// RemoveThreatDomain 从数据库中移除威胁域名
|
|
func (m *ThreatDatabaseManager) RemoveThreatDomain(domain string) error {
|
|
m.mutex.Lock()
|
|
delete(m.threatData, strings.ToLower(domain))
|
|
m.mutex.Unlock()
|
|
|
|
// 保存到文件
|
|
return m.saveDatabase()
|
|
}
|
|
|
|
// GetAllThreatDomains 获取所有威胁域名信息
|
|
func (m *ThreatDatabaseManager) GetAllThreatDomains() []*ThreatInfo {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
threats := make([]*ThreatInfo, 0, len(m.threatData))
|
|
for _, info := range m.threatData {
|
|
threats = append(threats, info)
|
|
}
|
|
|
|
return threats
|
|
}
|
|
|
|
// saveDatabase 保存数据库到文件
|
|
func (m *ThreatDatabaseManager) saveDatabase() error {
|
|
// 打开CSV文件
|
|
file, err := os.Create(m.databasePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer file.Close()
|
|
|
|
// 写入CSV文件
|
|
writer := csv.NewWriter(file)
|
|
defer writer.Flush()
|
|
|
|
// 写入表头
|
|
if err := writer.Write([]string{"type", "name", "riskLevel", "domain"}); err != nil {
|
|
return err
|
|
}
|
|
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
for _, info := range m.threatData {
|
|
record := []string{
|
|
info.Type,
|
|
info.Name,
|
|
strconv.Itoa(info.RiskLevel),
|
|
info.Domain,
|
|
}
|
|
if err := writer.Write(record); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// StartWatching 开始监听数据库文件变化
|
|
func (m *ThreatDatabaseManager) StartWatching() error {
|
|
m.watcherMutex.Lock()
|
|
defer m.watcherMutex.Unlock()
|
|
|
|
if m.isWatching {
|
|
logger.Info("文件监听已经在运行中")
|
|
return nil
|
|
}
|
|
|
|
var err error
|
|
m.watcher, err = fsnotify.NewWatcher()
|
|
if err != nil {
|
|
logger.Error("创建文件监听器失败", "error", err)
|
|
return err
|
|
}
|
|
|
|
// 确保目录存在
|
|
dir := filepath.Dir(m.databasePath)
|
|
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
|
logger.Error("数据库文件目录不存在", "dir", dir, "error", err)
|
|
return err
|
|
}
|
|
|
|
// 添加目录监听
|
|
if err := m.watcher.Add(dir); err != nil {
|
|
logger.Error("添加目录监听失败", "dir", dir, "error", err)
|
|
m.watcher.Close()
|
|
return err
|
|
}
|
|
|
|
m.isWatching = true
|
|
logger.Info("开始监听威胁域名数据库文件", "path", m.databasePath)
|
|
|
|
// 启动监听协程
|
|
go m.watchFileChanges()
|
|
|
|
return nil
|
|
}
|
|
|
|
// StopWatching 停止监听数据库文件变化
|
|
func (m *ThreatDatabaseManager) StopWatching() {
|
|
m.watcherMutex.Lock()
|
|
defer m.watcherMutex.Unlock()
|
|
|
|
if !m.isWatching {
|
|
return
|
|
}
|
|
|
|
close(m.watcherDone)
|
|
if m.watcher != nil {
|
|
m.watcher.Close()
|
|
}
|
|
m.isWatching = false
|
|
logger.Info("停止监听威胁域名数据库文件")
|
|
}
|
|
|
|
// watchFileChanges 监听文件变化并重新加载数据库
|
|
func (m *ThreatDatabaseManager) watchFileChanges() {
|
|
var reloadTimer *time.Timer
|
|
defer func() {
|
|
if reloadTimer != nil {
|
|
reloadTimer.Stop()
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case event, ok := <-m.watcher.Events:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
// 检查是否是我们关注的文件
|
|
if filepath.Clean(event.Name) != filepath.Clean(m.databasePath) {
|
|
continue
|
|
}
|
|
|
|
logger.Debug("监听到文件变化", "event", event.Op.String(), "file", event.Name)
|
|
|
|
// 只关注文件写入和创建事件
|
|
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
|
// 使用防抖机制,避免频繁重新加载
|
|
if reloadTimer != nil {
|
|
reloadTimer.Stop()
|
|
}
|
|
reloadTimer = time.AfterFunc(500*time.Millisecond, func() {
|
|
logger.Info("检测到数据库文件变化,开始重新加载", "path", m.databasePath)
|
|
if err := m.LoadDatabase(); err != nil {
|
|
logger.Error("重新加载数据库失败", "error", err)
|
|
} else {
|
|
logger.Info("数据库重新加载成功")
|
|
}
|
|
})
|
|
}
|
|
|
|
case err, ok := <-m.watcher.Errors:
|
|
if !ok {
|
|
return
|
|
}
|
|
logger.Error("文件监听器错误", "error", err)
|
|
|
|
case <-m.watcherDone:
|
|
return
|
|
}
|
|
}
|
|
}
|