// File: dns-server/log/sqlite_store.go
// Last commit: efebce3c39 "whois" — Alex Yang, 2026-04-01 12:22:55 +08:00
// (523 lines, 11 KiB, Go)
package log
import (
	"database/sql"
	"encoding/json"
	"fmt"
	"os"
	"strings"
	"sync"
	"time"

	_ "github.com/mattn/go-sqlite3"
)
// SQLiteStore is a query-log store backed by a SQLite database.
// Writes are buffered through batchChan and flushed in batches by a
// background goroutine started in NewSQLiteStore.
type SQLiteStore struct {
	db         *sql.DB       // SQLite handle; limited to one open connection (see NewSQLiteStore)
	batchChan  chan QueryLog // buffered queue feeding the batch-writer goroutine
	batchSize  int           // flush as soon as this many logs are pending
	batchDelay time.Duration // flush at least this often even when the batch is small
	mu         sync.Mutex    // guards batched DB writes and the closed flag
	closed     bool          // set by Close; subsequent Log calls are rejected
}
// NewSQLiteStore opens (or creates) the SQLite database at dbPath,
// ensures the schema exists, and starts the background batch writer.
// The caller must eventually call Close to flush pending logs and
// release the database.
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
	conn, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		return nil, fmt.Errorf("打开数据库失败:%w", err)
	}

	// SQLite handles concurrent writers poorly; funnel everything
	// through a single connection.
	conn.SetMaxOpenConns(1)
	conn.SetMaxIdleConns(1)
	conn.SetConnMaxLifetime(time.Hour)

	// WAL mode lets readers proceed while the batch writer commits.
	if _, err = conn.Exec("PRAGMA journal_mode=WAL"); err != nil {
		conn.Close()
		return nil, fmt.Errorf("启用 WAL 模式失败:%w", err)
	}

	if err = createTables(conn); err != nil {
		conn.Close()
		return nil, fmt.Errorf("创建表结构失败:%w", err)
	}

	store := &SQLiteStore{
		db:         conn,
		batchChan:  make(chan QueryLog, 10000),
		batchSize:  100,
		batchDelay: 100 * time.Millisecond,
	}

	// Background goroutine that drains batchChan; it exits when
	// Close closes the channel.
	go store.batchWriter()

	return store, nil
}
// createTables creates the query_logs table and its secondary
// indexes if they do not already exist. It is idempotent and safe to
// call on every startup.
func createTables(db *sql.DB) error {
	const schema = `
CREATE TABLE IF NOT EXISTS query_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME NOT NULL,
client_ip TEXT NOT NULL,
domain TEXT NOT NULL,
query_type TEXT NOT NULL,
response_time INTEGER NOT NULL,
result TEXT NOT NULL,
block_rule TEXT,
block_type TEXT,
from_cache BOOLEAN DEFAULT FALSE,
dnssec BOOLEAN DEFAULT FALSE,
edns BOOLEAN DEFAULT FALSE,
dns_server TEXT,
dnssec_server TEXT,
answers TEXT,
response_code INTEGER,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_timestamp ON query_logs(timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_domain ON query_logs(domain);
CREATE INDEX IF NOT EXISTS idx_client_ip ON query_logs(client_ip);
CREATE INDEX IF NOT EXISTS idx_result ON query_logs(result);
CREATE INDEX IF NOT EXISTS idx_query_type ON query_logs(query_type);
`
	if _, err := db.Exec(schema); err != nil {
		return err
	}
	return nil
}
// batchWriter runs in its own goroutine, draining batchChan and
// flushing accumulated logs either when the batch reaches batchSize
// or on every batchDelay tick. When batchChan is closed it writes
// any remaining logs and returns.
func (s *SQLiteStore) batchWriter() {
	pending := make([]QueryLog, 0, s.batchSize)

	// flush writes whatever has accumulated and resets the buffer,
	// keeping its capacity for reuse.
	flush := func() {
		if len(pending) > 0 {
			s.writeBatch(pending)
			pending = pending[:0]
		}
	}

	ticker := time.NewTicker(s.batchDelay)
	defer ticker.Stop()

	for {
		select {
		case entry, open := <-s.batchChan:
			if !open {
				// Channel closed by Close: persist the remainder and stop.
				flush()
				return
			}
			pending = append(pending, entry)
			if len(pending) >= s.batchSize {
				flush()
			}
		case <-ticker.C:
			// Periodic flush so small batches don't sit forever.
			flush()
		}
	}
}
// writeBatch inserts a slice of logs inside a single transaction.
// Errors are reported to stdout; a failing individual insert skips
// that entry without aborting the rest of the batch.
func (s *SQLiteStore) writeBatch(batch []QueryLog) {
	if len(batch) == 0 {
		return
	}

	// Serialize with other writers (Close, Log, migration).
	s.mu.Lock()
	defer s.mu.Unlock()

	tx, err := s.db.Begin()
	if err != nil {
		fmt.Printf("开启事务失败:%v\n", err)
		return
	}

	stmt, err := tx.Prepare(`
INSERT INTO query_logs (
timestamp, client_ip, domain, query_type, response_time,
result, block_rule, block_type, from_cache, dnssec, edns,
dns_server, dnssec_server, answers, response_code
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`)
	if err != nil {
		tx.Rollback()
		fmt.Printf("准备语句失败:%v\n", err)
		return
	}
	defer stmt.Close()

	for _, entry := range batch {
		if _, execErr := stmt.Exec(
			entry.Timestamp, entry.ClientIP, entry.Domain, entry.QueryType, entry.ResponseTime,
			entry.Result, entry.BlockRule, entry.BlockType, entry.FromCache, entry.DNSSEC, entry.EDNS,
			entry.DNSServer, entry.DNSSECServer, entry.Answers, entry.ResponseCode,
		); execErr != nil {
			fmt.Printf("插入日志失败:%v\n", execErr)
		}
	}

	if err = tx.Commit(); err != nil {
		fmt.Printf("提交事务失败:%v\n", err)
	}
}
// Log enqueues a query log for asynchronous batch insertion.
//
// Fix: the previous version read s.closed without holding s.mu and
// then sent on batchChan unconditionally, which could race with
// Close (which closes batchChan) and panic with "send on closed
// channel". The check and the send now happen under s.mu, the same
// mutex Close holds while closing the channel.
//
// The send is non-blocking: if the 10000-entry buffer is full the
// log is dropped with an error instead of blocking. Blocking here
// while holding s.mu could deadlock against batchWriter, whose
// writeBatch also needs s.mu to drain the channel.
func (s *SQLiteStore) Log(log QueryLog) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return fmt.Errorf("存储已关闭")
	}
	select {
	case s.batchChan <- log:
		return nil
	default:
		// Queue saturated; shed load rather than risk a deadlock.
		return fmt.Errorf("日志队列已满")
	}
}
// QueryLogs returns one page of query logs matching filter, together
// with the total number of matching rows (for pagination).
//
// Fixes over the previous version:
//   - page.SortField and page.SortDirection were interpolated
//     directly into the SQL text, a SQL injection hole; they are now
//     validated against a whitelist of real column names and ASC/DESC.
//   - rows.Err() is checked after iteration.
//   - debug fmt.Printf tracing removed.
func (s *SQLiteStore) QueryLogs(filter LogFilter, page PageParams) ([]QueryLog, int64, error) {
	// Build the WHERE clause from whichever filter fields are set;
	// all values go through placeholders.
	whereClause := "1=1"
	args := []interface{}{}
	if filter.Result != "" {
		whereClause += " AND result = ?"
		args = append(args, filter.Result)
	}
	if filter.QueryType != "" {
		whereClause += " AND query_type = ?"
		args = append(args, filter.QueryType)
	}
	if !filter.StartTime.IsZero() {
		whereClause += " AND timestamp >= ?"
		args = append(args, filter.StartTime)
	}
	if !filter.EndTime.IsZero() {
		whereClause += " AND timestamp <= ?"
		args = append(args, filter.EndTime)
	}
	if filter.SearchTerm != "" {
		whereClause += " AND (domain LIKE ? OR client_ip LIKE ?)"
		searchTerm := "%" + filter.SearchTerm + "%"
		args = append(args, searchTerm, searchTerm)
	}

	// Total matching rows (for the pagination footer).
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM query_logs WHERE %s", whereClause)
	var total int64
	if err := s.db.QueryRow(countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("查询总数失败:%w", err)
	}

	// Sort parameters are spliced into the SQL text, so restrict them
	// to known column names and directions (SQL injection guard).
	// Anything unrecognized falls back to the defaults the previous
	// version used for empty values: timestamp DESC.
	sortField := page.SortField
	switch sortField {
	case "id", "timestamp", "client_ip", "domain", "query_type",
		"response_time", "result", "response_code":
		// allowed as-is
	default:
		sortField = "timestamp"
	}
	sortDirection := "DESC"
	if strings.EqualFold(page.SortDirection, "ASC") {
		sortDirection = "ASC"
	}

	query := fmt.Sprintf(`
SELECT id, timestamp, client_ip, domain, query_type, response_time,
result, block_rule, block_type, from_cache, dnssec, edns,
dns_server, dnssec_server, answers, response_code
FROM query_logs
WHERE %s
ORDER BY %s %s
LIMIT ? OFFSET ?
`, whereClause, sortField, sortDirection)
	args = append(args, page.Limit, page.Offset)

	rows, err := s.db.Query(query, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("查询日志失败:%w", err)
	}
	defer rows.Close()

	var logs []QueryLog
	for rows.Next() {
		var entry QueryLog
		if err := rows.Scan(
			&entry.ID,
			&entry.Timestamp,
			&entry.ClientIP,
			&entry.Domain,
			&entry.QueryType,
			&entry.ResponseTime,
			&entry.Result,
			&entry.BlockRule,
			&entry.BlockType,
			&entry.FromCache,
			&entry.DNSSEC,
			&entry.EDNS,
			&entry.DNSServer,
			&entry.DNSSECServer,
			&entry.Answers,
			&entry.ResponseCode,
		); err != nil {
			return nil, 0, fmt.Errorf("扫描日志失败:%w", err)
		}
		logs = append(logs, entry)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("遍历日志失败:%w", err)
	}
	return logs, total, nil
}
// GetStats aggregates query-log statistics (counters, average
// latency, query-type distribution, top-10 domains and clients) for
// the given time range. A zero StartTime/EndTime leaves that bound
// open.
//
// Fix: SUM and AVG return NULL when no rows match, which made Scan
// fail on an empty time range; the aggregates are now wrapped in
// COALESCE so an empty range yields zeroed stats. rows.Err() is also
// checked after each GROUP BY iteration.
func (s *SQLiteStore) GetStats(timeRange TimeRange) (*LogStats, error) {
	// Time-range condition; values go through placeholders.
	whereClause := "1=1"
	args := []interface{}{}
	if !timeRange.StartTime.IsZero() {
		whereClause += " AND timestamp >= ?"
		args = append(args, timeRange.StartTime)
	}
	if !timeRange.EndTime.IsZero() {
		whereClause += " AND timestamp <= ?"
		args = append(args, timeRange.EndTime)
	}

	stats := &LogStats{
		QueryTypes: make(map[string]int64),
	}

	// Basic counters and average response time. COALESCE guards
	// against NULL aggregates when the range matches no rows.
	query := fmt.Sprintf(`
SELECT
COUNT(*) as total,
COALESCE(SUM(CASE WHEN result = 'blocked' THEN 1 ELSE 0 END), 0) as blocked,
COALESCE(SUM(CASE WHEN result = 'allowed' THEN 1 ELSE 0 END), 0) as allowed,
COALESCE(SUM(CASE WHEN result = 'error' THEN 1 ELSE 0 END), 0) as error,
COALESCE(AVG(response_time), 0) as avg_response_time
FROM query_logs
WHERE %s
`, whereClause)
	err := s.db.QueryRow(query, args...).Scan(
		&stats.TotalQueries,
		&stats.BlockedQueries,
		&stats.AllowedQueries,
		&stats.ErrorQueries,
		&stats.AvgResponseTime,
	)
	if err != nil {
		return nil, fmt.Errorf("查询基础统计失败:%w", err)
	}

	// groupCount runs a "key, COUNT(*)" GROUP BY query and feeds each
	// row to handle. Rows that fail to scan are skipped, matching the
	// previous behavior.
	groupCount := func(q string, handle func(key string, count int64)) error {
		rows, err := s.db.Query(q, args...)
		if err != nil {
			return err
		}
		defer rows.Close()
		for rows.Next() {
			var key string
			var count int64
			if err := rows.Scan(&key, &count); err != nil {
				continue
			}
			handle(key, count)
		}
		return rows.Err()
	}

	// Query-type distribution.
	query = fmt.Sprintf(`
SELECT query_type, COUNT(*) as count
FROM query_logs
WHERE %s
GROUP BY query_type
`, whereClause)
	if err := groupCount(query, func(queryType string, count int64) {
		stats.QueryTypes[queryType] = count
	}); err != nil {
		return nil, fmt.Errorf("查询类型分布失败:%w", err)
	}

	// Top 10 domains by query count.
	query = fmt.Sprintf(`
SELECT domain, COUNT(*) as count
FROM query_logs
WHERE %s
GROUP BY domain
ORDER BY count DESC
LIMIT 10
`, whereClause)
	if err := groupCount(query, func(domain string, count int64) {
		stats.TopDomains = append(stats.TopDomains, DomainCount{Domain: domain, Count: count})
	}); err != nil {
		return nil, fmt.Errorf("查询 TOP 域名失败:%w", err)
	}

	// Top 10 clients by query count.
	query = fmt.Sprintf(`
SELECT client_ip, COUNT(*) as count
FROM query_logs
WHERE %s
GROUP BY client_ip
ORDER BY count DESC
LIMIT 10
`, whereClause)
	if err := groupCount(query, func(ip string, count int64) {
		stats.TopClients = append(stats.TopClients, ClientCount{IP: ip, Count: count})
	}); err != nil {
		return nil, fmt.Errorf("查询 TOP 客户端失败:%w", err)
	}

	return stats, nil
}
// Close marks the store closed, signals the batch writer to drain by
// closing batchChan, waits briefly for the drain, and then closes
// the database. It is idempotent.
//
// Fix: the previous version held s.mu through the 200 ms sleep. The
// batch writer's final writeBatch (triggered by the channel close)
// also needs s.mu, so it could not run until Close returned — i.e.
// after db.Close() — and the last batch of logs was lost. The mutex
// is now released before waiting so the drain can proceed.
func (s *SQLiteStore) Close() error {
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return nil
	}
	s.closed = true
	// Only close the channel under the mutex so Log (which sends
	// under the same mutex) cannot race with the close.
	close(s.batchChan)
	s.mu.Unlock()

	// Give batchWriter time to flush the remaining logs.
	// NOTE(review): a fixed sleep is a heuristic, not a guarantee; a
	// done channel signaled by batchWriter would be deterministic.
	time.Sleep(200 * time.Millisecond)

	return s.db.Close()
}
// MigrateFromJSON imports legacy query logs from the JSON file at
// jsonPath into the query_logs table, inside a single transaction.
// A file containing an empty array is a no-op.
func (s *SQLiteStore) MigrateFromJSON(jsonPath string) error {
	raw, err := readFile(jsonPath)
	if err != nil {
		return fmt.Errorf("读取 JSON 文件失败:%w", err)
	}

	var entries []QueryLog
	if err = json.Unmarshal(raw, &entries); err != nil {
		return fmt.Errorf("解析 JSON 失败:%w", err)
	}
	if len(entries) == 0 {
		// Nothing to migrate.
		return nil
	}

	tx, err := s.db.Begin()
	if err != nil {
		return fmt.Errorf("开启事务失败:%w", err)
	}

	stmt, err := tx.Prepare(`
INSERT INTO query_logs (
timestamp, client_ip, domain, query_type, response_time,
result, block_rule, block_type, from_cache, dnssec, edns,
dns_server, dnssec_server, answers, response_code
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`)
	if err != nil {
		tx.Rollback()
		return fmt.Errorf("准备语句失败:%w", err)
	}
	defer stmt.Close()

	// One row per entry; a failing entry is reported and skipped
	// rather than aborting the whole migration.
	for _, entry := range entries {
		if _, execErr := stmt.Exec(
			entry.Timestamp, entry.ClientIP, entry.Domain, entry.QueryType, entry.ResponseTime,
			entry.Result, entry.BlockRule, entry.BlockType, entry.FromCache, entry.DNSSEC, entry.EDNS,
			entry.DNSServer, entry.DNSSECServer, entry.Answers, entry.ResponseCode,
		); execErr != nil {
			fmt.Printf("迁移日志失败:%v\n", execErr)
		}
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("提交事务失败:%w", err)
	}
	return nil
}
// readFile 读取文件内容(辅助函数)
func readFile(path string) ([]byte, error) {
return os.ReadFile(path)
}