package log import ( "database/sql" "encoding/json" "fmt" "os" "sync" "time" _ "github.com/mattn/go-sqlite3" ) // SQLiteStore SQLite 存储引擎 type SQLiteStore struct { db *sql.DB batchChan chan QueryLog batchSize int batchDelay time.Duration mu sync.Mutex closed bool } // NewSQLiteStore 创建 SQLite 存储 func NewSQLiteStore(dbPath string) (*SQLiteStore, error) { // 打开数据库连接 db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, fmt.Errorf("打开数据库失败:%w", err) } // 设置连接参数 db.SetMaxOpenConns(1) // SQLite 不支持高并发写入 db.SetMaxIdleConns(1) db.SetConnMaxLifetime(time.Hour) // 启用 WAL 模式 _, err = db.Exec("PRAGMA journal_mode=WAL") if err != nil { db.Close() return nil, fmt.Errorf("启用 WAL 模式失败:%w", err) } // 创建表结构 err = createTables(db) if err != nil { db.Close() return nil, fmt.Errorf("创建表结构失败:%w", err) } store := &SQLiteStore{ db: db, batchChan: make(chan QueryLog, 10000), batchSize: 100, batchDelay: 100 * time.Millisecond, } // 启动批量写入协程 go store.batchWriter() return store, nil } // createTables 创建数据库表 func createTables(db *sql.DB) error { schema := ` CREATE TABLE IF NOT EXISTS query_logs ( id INTEGER PRIMARY KEY AUTOINCREMENT, timestamp DATETIME NOT NULL, client_ip TEXT NOT NULL, domain TEXT NOT NULL, query_type TEXT NOT NULL, response_time INTEGER NOT NULL, result TEXT NOT NULL, block_rule TEXT, block_type TEXT, from_cache BOOLEAN DEFAULT FALSE, dnssec BOOLEAN DEFAULT FALSE, edns BOOLEAN DEFAULT FALSE, dns_server TEXT, dnssec_server TEXT, answers TEXT, response_code INTEGER, created_at DATETIME DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_timestamp ON query_logs(timestamp DESC); CREATE INDEX IF NOT EXISTS idx_domain ON query_logs(domain); CREATE INDEX IF NOT EXISTS idx_client_ip ON query_logs(client_ip); CREATE INDEX IF NOT EXISTS idx_result ON query_logs(result); CREATE INDEX IF NOT EXISTS idx_query_type ON query_logs(query_type); ` _, err := db.Exec(schema) return err } // batchWriter 批量写入协程 func (s *SQLiteStore) batchWriter() { batch := make([]QueryLog, 0, s.batchSize) ticker := time.NewTicker(s.batchDelay) defer ticker.Stop() for { select { case log, ok := <-s.batchChan: if !ok { // 通道关闭,写入剩余日志 if len(batch) > 0 { s.writeBatch(batch) } return } batch = append(batch, log) // 达到批量大小时立即写入 if len(batch) >= s.batchSize { s.writeBatch(batch) batch = batch[:0] } case <-ticker.C: // 定时写入 if len(batch) > 0 { s.writeBatch(batch) batch = batch[:0] } } } } // writeBatch 批量写入日志 func (s *SQLiteStore) writeBatch(batch []QueryLog) { if len(batch) == 0 { return } s.mu.Lock() defer s.mu.Unlock() // 开启事务 tx, err := s.db.Begin() if err != nil { fmt.Printf("开启事务失败:%v\n", err) return } // 准备插入语句 stmt, err := tx.Prepare(` INSERT INTO query_logs ( timestamp, client_ip, domain, query_type, response_time, result, block_rule, block_type, from_cache, dnssec, edns, dns_server, dnssec_server, answers, response_code ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) `) if err != nil { tx.Rollback() fmt.Printf("准备语句失败:%v\n", err) return } defer stmt.Close() // 批量插入 for _, log := range batch { _, err := stmt.Exec( log.Timestamp, log.ClientIP, log.Domain, log.QueryType, log.ResponseTime, log.Result, log.BlockRule, log.BlockType, log.FromCache, log.DNSSEC, log.EDNS, log.DNSServer, log.DNSSECServer, log.Answers, log.ResponseCode, ) if err != nil { fmt.Printf("插入日志失败:%v\n", err) continue } } // 提交事务 err = tx.Commit() if err != nil { fmt.Printf("提交事务失败:%v\n", err) } } // Log 记录日志 func (s *SQLiteStore) Log(log QueryLog) error { if s.closed { return fmt.Errorf("存储已关闭") } s.batchChan <- log return nil } // QueryLogs 查询日志 func (s *SQLiteStore) QueryLogs(filter LogFilter, page PageParams) ([]QueryLog, int64, error) { fmt.Printf("SQLiteStore.QueryLogs called: filter=%+v, page=%+v\n", filter, page) // 构建查询条件 whereClause := "1=1" args := []interface{}{} if filter.Result != "" { whereClause += " AND result = ?" args = append(args, filter.Result) } if filter.QueryType != "" { whereClause += " AND query_type = ?" args = append(args, filter.QueryType) } if !filter.StartTime.IsZero() { whereClause += " AND timestamp >= ?" args = append(args, filter.StartTime) } if !filter.EndTime.IsZero() { whereClause += " AND timestamp <= ?" args = append(args, filter.EndTime) } if filter.SearchTerm != "" { whereClause += " AND (domain LIKE ? OR client_ip LIKE ?)" searchTerm := "%" + filter.SearchTerm + "%" args = append(args, searchTerm, searchTerm) } fmt.Printf("SQLite WHERE clause: %s, args: %v\n", whereClause, args) // 获取总数 countQuery := fmt.Sprintf("SELECT COUNT(*) FROM query_logs WHERE %s", whereClause) var total int64 err := s.db.QueryRow(countQuery, args...).Scan(&total) if err != nil { return nil, 0, fmt.Errorf("查询总数失败:%w", err) } fmt.Printf("SQLite total count: %d\n", total) // 构建排序 sortField := page.SortField if sortField == "" { sortField = "timestamp" } sortDirection := page.SortDirection if sortDirection == "" { sortDirection = "DESC" } // 查询日志 query := fmt.Sprintf(` SELECT id, timestamp, client_ip, domain, query_type, response_time, result, block_rule, block_type, from_cache, dnssec, edns, dns_server, dnssec_server, answers, response_code FROM query_logs WHERE %s ORDER BY %s %s LIMIT ? OFFSET ? `, whereClause, sortField, sortDirection) args = append(args, page.Limit, page.Offset) rows, err := s.db.Query(query, args...) if err != nil { return nil, 0, fmt.Errorf("查询日志失败:%w", err) } defer rows.Close() var logs []QueryLog for rows.Next() { var log QueryLog err := rows.Scan( &log.ID, &log.Timestamp, &log.ClientIP, &log.Domain, &log.QueryType, &log.ResponseTime, &log.Result, &log.BlockRule, &log.BlockType, &log.FromCache, &log.DNSSEC, &log.EDNS, &log.DNSServer, &log.DNSSECServer, &log.Answers, &log.ResponseCode, ) if err != nil { return nil, 0, fmt.Errorf("扫描日志失败:%w", err) } logs = append(logs, log) } return logs, total, nil } // GetStats 获取统计信息 func (s *SQLiteStore) GetStats(timeRange TimeRange) (*LogStats, error) { // 构建时间范围条件 whereClause := "1=1" args := []interface{}{} if !timeRange.StartTime.IsZero() { whereClause += " AND timestamp >= ?" args = append(args, timeRange.StartTime) } if !timeRange.EndTime.IsZero() { whereClause += " AND timestamp <= ?" args = append(args, timeRange.EndTime) } stats := &LogStats{ QueryTypes: make(map[string]int64), } // 基础统计 query := fmt.Sprintf(` SELECT COUNT(*) as total, SUM(CASE WHEN result = 'blocked' THEN 1 ELSE 0 END) as blocked, SUM(CASE WHEN result = 'allowed' THEN 1 ELSE 0 END) as allowed, SUM(CASE WHEN result = 'error' THEN 1 ELSE 0 END) as error, AVG(response_time) as avg_response_time FROM query_logs WHERE %s `, whereClause) err := s.db.QueryRow(query, args...).Scan( &stats.TotalQueries, &stats.BlockedQueries, &stats.AllowedQueries, &stats.ErrorQueries, &stats.AvgResponseTime, ) if err != nil { return nil, fmt.Errorf("查询基础统计失败:%w", err) } // 查询类型分布 query = fmt.Sprintf(` SELECT query_type, COUNT(*) as count FROM query_logs WHERE %s GROUP BY query_type `, whereClause) rows, err := s.db.Query(query, args...) if err != nil { return nil, fmt.Errorf("查询类型分布失败:%w", err) } defer rows.Close() for rows.Next() { var queryType string var count int64 err := rows.Scan(&queryType, &count) if err != nil { continue } stats.QueryTypes[queryType] = count } // TOP 域名 query = fmt.Sprintf(` SELECT domain, COUNT(*) as count FROM query_logs WHERE %s GROUP BY domain ORDER BY count DESC LIMIT 10 `, whereClause) rows, err = s.db.Query(query, args...) if err != nil { return nil, fmt.Errorf("查询 TOP 域名失败:%w", err) } defer rows.Close() for rows.Next() { var domain string var count int64 err := rows.Scan(&domain, &count) if err != nil { continue } stats.TopDomains = append(stats.TopDomains, DomainCount{Domain: domain, Count: count}) } // TOP 客户端 query = fmt.Sprintf(` SELECT client_ip, COUNT(*) as count FROM query_logs WHERE %s GROUP BY client_ip ORDER BY count DESC LIMIT 10 `, whereClause) rows, err = s.db.Query(query, args...) if err != nil { return nil, fmt.Errorf("查询 TOP 客户端失败:%w", err) } defer rows.Close() for rows.Next() { var ip string var count int64 err := rows.Scan(&ip, &count) if err != nil { continue } stats.TopClients = append(stats.TopClients, ClientCount{IP: ip, Count: count}) } return stats, nil } // Close 关闭存储 func (s *SQLiteStore) Close() error { s.mu.Lock() defer s.mu.Unlock() if s.closed { return nil } s.closed = true close(s.batchChan) // 等待批量写入完成 time.Sleep(200 * time.Millisecond) return s.db.Close() } // MigrateFromJSON 从 JSON 文件迁移数据 func (s *SQLiteStore) MigrateFromJSON(jsonPath string) error { // 读取 JSON 文件 data, err := readFile(jsonPath) if err != nil { return fmt.Errorf("读取 JSON 文件失败:%w", err) } var logs []QueryLog err = json.Unmarshal(data, &logs) if err != nil { return fmt.Errorf("解析 JSON 失败:%w", err) } if len(logs) == 0 { return nil } // 批量插入 tx, err := s.db.Begin() if err != nil { return fmt.Errorf("开启事务失败:%w", err) } stmt, err := tx.Prepare(` INSERT INTO query_logs ( timestamp, client_ip, domain, query_type, response_time, result, block_rule, block_type, from_cache, dnssec, edns, dns_server, dnssec_server, answers, response_code ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) `) if err != nil { tx.Rollback() return fmt.Errorf("准备语句失败:%w", err) } defer stmt.Close() for _, log := range logs { _, err := stmt.Exec( log.Timestamp, log.ClientIP, log.Domain, log.QueryType, log.ResponseTime, log.Result, log.BlockRule, log.BlockType, log.FromCache, log.DNSSEC, log.EDNS, log.DNSServer, log.DNSSECServer, log.Answers, log.ResponseCode, ) if err != nil { fmt.Printf("迁移日志失败:%v\n", err) continue } } err = tx.Commit() if err != nil { return fmt.Errorf("提交事务失败:%w", err) } return nil } // readFile 读取文件内容(辅助函数) func readFile(path string) ([]byte, error) { return os.ReadFile(path) }