768 lines
21 KiB
Go
768 lines
21 KiB
Go
package http
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"io/ioutil"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"dns-server/config"
|
||
"dns-server/dns"
|
||
"dns-server/logger"
|
||
"dns-server/shield"
|
||
)
|
||
|
||
// Server HTTP控制台服务器
|
||
type Server struct {
|
||
globalConfig *config.Config
|
||
config *config.HTTPConfig
|
||
dnsServer *dns.Server
|
||
shieldManager *shield.ShieldManager
|
||
server *http.Server
|
||
}
|
||
|
||
// NewServer 创建HTTP服务器实例
|
||
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager) *Server {
|
||
return &Server{
|
||
globalConfig: globalConfig,
|
||
config: &globalConfig.HTTP,
|
||
dnsServer: dnsServer,
|
||
shieldManager: shieldManager,
|
||
}
|
||
}
|
||
|
||
// Start registers the console routes and then blocks serving HTTP on the
// configured host and port.
//
// When EnableAPI is set, the /api/* JSON endpoints are registered. Every other
// path is served from the ./static directory. Start returns nil once the
// server has been closed through Stop, and any other listen/serve error
// unchanged.
func (s *Server) Start() error {
	mux := http.NewServeMux()

	// API routes.
	if s.config.EnableAPI {
		mux.HandleFunc("/api/stats", s.handleStats)
		mux.HandleFunc("/api/shield", s.handleShield)
		mux.HandleFunc("/api/shield/hosts", s.handleShieldHosts)
		mux.HandleFunc("/api/shield/blacklists", s.handleShieldBlacklists)
		// Subtree pattern (trailing slash). Without it, requests such as
		// /api/shield/blacklists/{id} and /api/shield/blacklists/{id}/update
		// are matched by "/" and go to the static file server. The
		// per-blacklist branches of handleShieldBlacklists would then be
		// unreachable.
		mux.HandleFunc("/api/shield/blacklists/", s.handleShieldBlacklists)
		mux.HandleFunc("/api/query", s.handleQuery)
		mux.HandleFunc("/api/status", s.handleStatus)
		mux.HandleFunc("/api/config", s.handleConfig)

		// Statistics endpoints used by the dashboard.
		mux.HandleFunc("/api/top-blocked", s.handleTopBlockedDomains)
		mux.HandleFunc("/api/top-resolved", s.handleTopResolvedDomains)
		mux.HandleFunc("/api/recent-blocked", s.handleRecentBlockedDomains)
		mux.HandleFunc("/api/hourly-stats", s.handleHourlyStats)
		mux.HandleFunc("/api/daily-stats", s.handleDailyStats)
		mux.HandleFunc("/api/monthly-stats", s.handleMonthlyStats)
	}

	// Static files (front-end); everything not matched above lands here.
	mux.Handle("/", http.FileServer(http.Dir("./static")))

	// An IPv6 literal host must be bracketed. Otherwise "::1" and port 8080
	// would produce the unparseable address "::1:8080".
	host := s.config.Host
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}
	addr := fmt.Sprintf("%s:%d", host, s.config.Port)

	s.server = &http.Server{
		Addr:         addr,
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	logger.Info(fmt.Sprintf("HTTP控制台服务器启动,监听地址: %s", addr))

	err := s.server.ListenAndServe()
	if err == http.ErrServerClosed {
		// Stop closed the server on purpose; this is not a failure.
		return nil
	}
	return err
}
|
||
|
||
// Stop 停止HTTP服务器
|
||
func (s *Server) Stop() {
|
||
if s.server != nil {
|
||
s.server.Close()
|
||
}
|
||
logger.Info("HTTP控制台服务器已停止")
|
||
}
|
||
|
||
// handleStats serves GET /api/stats: a snapshot of the DNS server and shield
// statistics, plus a few derived values for the dashboard.
//
// Response keys:
//   - "dns": the raw DNS counters
//   - "shield": the shield manager's GetStats result
//   - "topQueryType": the most frequent query type, or "-" if none has a
//     positive count
//   - "activeIPs": the number of entries in SourceIPs
//   - "avgResponseTime", "cpuUsage": copied from the DNS stats
//   - "time": the server time of the snapshot
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	dnsStats := s.dnsServer.GetStats()
	shieldStats := s.shieldManager.GetStats()

	// Pick the most frequent query type. Ties are broken alphabetically so
	// repeated polls give the same answer despite Go's randomized map
	// iteration order. (Ranging over an empty or nil map is a no-op.)
	topQueryType := "-"
	maxCount := int64(0)
	for queryType, count := range dnsStats.QueryTypes {
		if count > maxCount || (count == maxCount && count > 0 && queryType < topQueryType) {
			maxCount = count
			topQueryType = queryType
		}
	}

	stats := map[string]interface{}{
		"dns": map[string]interface{}{
			"Queries":           dnsStats.Queries,
			"Blocked":           dnsStats.Blocked,
			"Allowed":           dnsStats.Allowed,
			"Errors":            dnsStats.Errors,
			"LastQuery":         dnsStats.LastQuery,
			"AvgResponseTime":   dnsStats.AvgResponseTime,
			"TotalResponseTime": dnsStats.TotalResponseTime,
			"QueryTypes":        dnsStats.QueryTypes,
			"SourceIPs":         dnsStats.SourceIPs,
			"CpuUsage":          dnsStats.CpuUsage,
		},
		"shield":          shieldStats,
		"topQueryType":    topQueryType,
		"activeIPs":       len(dnsStats.SourceIPs),
		"avgResponseTime": dnsStats.AvgResponseTime,
		"cpuUsage":        dnsStats.CpuUsage,
		"time":            time.Now(),
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(stats)
}
|
||
|
||
// handleTopBlockedDomains serves GET /api/top-blocked. It returns the entries
// from GetTopBlockedDomains(10), in the order received, as a JSON array of
// {"domain": ..., "count": ...} objects. An empty result encodes as [].
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		top := s.dnsServer.GetTopBlockedDomains(10)

		rows := make([]map[string]interface{}, 0, len(top))
		for _, entry := range top {
			rows = append(rows, map[string]interface{}{
				"domain": entry.Domain,
				"count":  entry.Count,
			})
		}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(rows)
		return
	}
	http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
|
||
|
||
// handleTopResolvedDomains serves GET /api/top-resolved. It returns the
// entries from GetTopResolvedDomains(10), in the order received, as a JSON
// array of {"domain": ..., "count": ...} objects. An empty result encodes
// as [].
func (s *Server) handleTopResolvedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		top := s.dnsServer.GetTopResolvedDomains(10)

		rows := make([]map[string]interface{}, 0, len(top))
		for _, entry := range top {
			rows = append(rows, map[string]interface{}{
				"domain": entry.Domain,
				"count":  entry.Count,
			})
		}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(rows)
		return
	}
	http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
|
||
|
||
// handleRecentBlockedDomains serves GET /api/recent-blocked. It returns the
// entries from GetRecentBlockedDomains(10), in the order received, as a JSON
// array of {"domain": ..., "time": "HH:MM:SS"} objects. The time comes from
// each entry's LastSeen field.
func (s *Server) handleRecentBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		recent := s.dnsServer.GetRecentBlockedDomains(10)

		rows := make([]map[string]interface{}, 0, len(recent))
		for _, entry := range recent {
			rows = append(rows, map[string]interface{}{
				"domain": entry.Domain,
				"time":   entry.LastSeen.Format("15:04:05"),
			})
		}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(rows)
		return
	}
	http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
|
||
|
||
// handleHourlyStats 处理24小时统计请求
|
||
func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
hourlyStats := s.dnsServer.GetHourlyStats()
|
||
|
||
// 生成最近24小时的数据
|
||
now := time.Now()
|
||
labels := make([]string, 24)
|
||
data := make([]int64, 24)
|
||
|
||
for i := 23; i >= 0; i-- {
|
||
hour := now.Add(time.Duration(-i) * time.Hour)
|
||
hourKey := hour.Format("2006-01-02-15")
|
||
labels[23-i] = hour.Format("15:00")
|
||
data[23-i] = hourlyStats[hourKey]
|
||
}
|
||
|
||
result := map[string]interface{}{
|
||
"labels": labels,
|
||
"data": data,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleDailyStats 处理每日统计数据请求
|
||
func (s *Server) handleDailyStats(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取每日统计数据
|
||
dailyStats := s.dnsServer.GetDailyStats()
|
||
|
||
// 生成过去7天的时间标签
|
||
labels := make([]string, 7)
|
||
data := make([]int64, 7)
|
||
now := time.Now()
|
||
|
||
for i := 6; i >= 0; i-- {
|
||
t := now.AddDate(0, 0, -i)
|
||
key := t.Format("2006-01-02")
|
||
labels[6-i] = t.Format("01-02")
|
||
data[6-i] = dailyStats[key]
|
||
}
|
||
|
||
result := map[string]interface{}{
|
||
"labels": labels,
|
||
"data": data,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleMonthlyStats 处理每月统计数据请求
|
||
func (s *Server) handleMonthlyStats(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取每日统计数据(用于30天视图)
|
||
dailyStats := s.dnsServer.GetDailyStats()
|
||
|
||
// 生成过去30天的时间标签
|
||
labels := make([]string, 30)
|
||
data := make([]int64, 30)
|
||
now := time.Now()
|
||
|
||
for i := 29; i >= 0; i-- {
|
||
t := now.AddDate(0, 0, -i)
|
||
key := t.Format("2006-01-02")
|
||
labels[29-i] = t.Format("01-02")
|
||
data[29-i] = dailyStats[key]
|
||
}
|
||
|
||
result := map[string]interface{}{
|
||
"labels": labels,
|
||
"data": data,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleShield serves /api/shield, the rule-management endpoint.
//
//	GET    returns the shield settings from the global config plus rule
//	       counts from the shield manager (not the full rule lists)
//	POST   adds one rule; body {"rule": "..."}
//	DELETE removes one rule; body {"rule": "..."}
//	PUT    reloads all rules through the shield manager
//
// Blank rules are rejected with 400. Errors from the shield manager are
// returned as 500 with the error text. Successful mutations reply
// {"status": "success"}.
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodGet:
		stats := s.shieldManager.GetStats()
		shieldInfo := map[string]interface{}{
			"updateInterval":        s.globalConfig.Shield.UpdateInterval,
			"blockMethod":           s.globalConfig.Shield.BlockMethod,
			"blacklistCount":        len(s.globalConfig.Shield.Blacklists),
			"domainRulesCount":      stats["domainRules"],
			"domainExceptionsCount": stats["domainExceptions"],
			"regexRulesCount":       stats["regexRules"],
			"regexExceptionsCount":  stats["regexExceptions"],
			"hostsRulesCount":       stats["hostsRules"],
		}
		json.NewEncoder(w).Encode(shieldInfo)

	case http.MethodPost:
		var req struct {
			Rule string `json:"rule"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		// A missing or whitespace-only "rule" is a client error. Do not
		// forward it to the shield manager.
		if strings.TrimSpace(req.Rule) == "" {
			http.Error(w, "Rule is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddRule(req.Rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodDelete:
		var req struct {
			Rule string `json:"rule"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if strings.TrimSpace(req.Rule) == "" {
			http.Error(w, "Rule is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveRule(req.Rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodPut:
		if err := s.shieldManager.LoadRules(); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "规则重新加载成功"})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
|
||
|
||
// handleShieldBlacklists serves the remote-blacklist management API.
//
// Collection routes (/api/shield/blacklists, with or without trailing slash):
//
//	GET    list the blacklists held by the shield manager
//	POST   add a blacklist; body {"name": "...", "url": "..."}. Both are
//	       required, the URL must start with http:// or https://, and the
//	       URL must not already be present.
//	PUT    stamp every blacklist's LastUpdateTime and reload all rules
//
// Item routes (/api/shield/blacklists/{id}[/update]). {id} is compared with
// both the URL and the Name of each entry. The path is split on "/", so a
// URL (which contains slashes) cannot be used as {id}; use the name.
//
//	POST   /{id}/update  stamp that blacklist's LastUpdateTime and reload rules
//	DELETE /{id}         remove every blacklist whose URL or Name equals {id},
//	                     then reload rules
//
// Any other method on an item route gets 405, and any other action gets 404.
// A failed rule reload is reported as 500 with the error text. The item routes
// are only reachable when the subtree pattern "/api/shield/blacklists/" is
// registered on the mux.
func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// Split what follows "/api/shield/blacklists/" into an identifier and an
	// optional action, e.g. "easylist/update" -> id "easylist", action "update".
	const itemPrefix = "/api/shield/blacklists/"
	var id, action string
	if strings.HasPrefix(r.URL.Path, itemPrefix) {
		rest := strings.Trim(r.URL.Path[len(itemPrefix):], "/")
		segments := strings.SplitN(rest, "/", 2)
		id = segments[0]
		if len(segments) == 2 {
			action = segments[1]
		}
	}

	if id != "" {
		switch {
		case action == "update" && r.Method == http.MethodPost:
			blacklists := s.shieldManager.GetBlacklists()
			targetIndex := -1
			for i, list := range blacklists {
				if list.URL == id || list.Name == id {
					targetIndex = i
					break
				}
			}
			if targetIndex == -1 {
				http.Error(w, "黑名单不存在", http.StatusNotFound)
				return
			}

			blacklists[targetIndex].LastUpdateTime = time.Now().Format(time.RFC3339)
			s.shieldManager.UpdateBlacklist(blacklists)
			// Reload so the latest remote rules are picked up.
			if err := s.shieldManager.LoadRules(); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			json.NewEncoder(w).Encode(map[string]string{"status": "success"})

		case action == "" && r.Method == http.MethodDelete:
			blacklists := s.shieldManager.GetBlacklists()
			var remaining []config.BlacklistEntry
			for _, list := range blacklists {
				if list.URL != id && list.Name != id {
					remaining = append(remaining, list)
				}
			}
			if len(remaining) == len(blacklists) {
				http.Error(w, "黑名单不存在", http.StatusNotFound)
				return
			}

			s.shieldManager.UpdateBlacklist(remaining)
			// Reload, as the add and update paths do after changing the list,
			// so the removed list's rules are dropped. This assumes LoadRules
			// rebuilds the rules from the current blacklist set.
			if err := s.shieldManager.LoadRules(); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			json.NewEncoder(w).Encode(map[string]string{"status": "success"})

		case action == "update" || action == "":
			// Known item route, wrong method. Never fall through to another
			// branch: previously a DELETE on .../{id}/update deleted the list.
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)

		default:
			http.NotFound(w, r)
		}
		return
	}

	switch r.Method {
	case http.MethodGet:
		json.NewEncoder(w).Encode(s.shieldManager.GetBlacklists())

	case http.MethodPost:
		var req struct {
			Name string `json:"name"`
			URL  string `json:"url"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if req.Name == "" || req.URL == "" {
			http.Error(w, "Name and URL are required", http.StatusBadRequest)
			return
		}
		// The same URL rule that /api/config applies to blacklist entries.
		if !strings.HasPrefix(req.URL, "http://") && !strings.HasPrefix(req.URL, "https://") {
			http.Error(w, "黑名单URL必须以http://或https://开头", http.StatusBadRequest)
			return
		}

		blacklists := s.shieldManager.GetBlacklists()
		for _, list := range blacklists {
			if list.URL == req.URL {
				http.Error(w, "黑名单URL已存在", http.StatusConflict)
				return
			}
		}

		blacklists = append(blacklists, config.BlacklistEntry{
			Name:    req.Name,
			URL:     req.URL,
			Enabled: true,
		})
		s.shieldManager.UpdateBlacklist(blacklists)
		// Reload so the newly added remote rules are fetched.
		if err := s.shieldManager.LoadRules(); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodPut:
		blacklists := s.shieldManager.GetBlacklists()
		stamp := time.Now().Format(time.RFC3339)
		for i := range blacklists {
			blacklists[i].LastUpdateTime = stamp
		}
		s.shieldManager.UpdateBlacklist(blacklists)
		if err := s.shieldManager.LoadRules(); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
|
||
|
||
// handleShieldHosts serves /api/shield/hosts, the custom hosts-entry API.
//
//	GET    {"hosts": [{"domain": ..., "ip": ...}, ...], "hostsCount": n}.
//	       The list is built from a map, so its order is unspecified.
//	POST   add an entry; body {"ip": "...", "domain": "..."}, both required
//	DELETE remove an entry; body {"domain": "..."}, domain required
//
// Errors from the shield manager are returned as 500 with the error text.
// Successful mutations reply {"status": "success"}.
func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodPost:
		var req struct {
			IP     string `json:"ip"`
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if req.IP == "" || req.Domain == "" {
			http.Error(w, "IP and Domain are required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddHostsEntry(req.IP, req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodDelete:
		var req struct {
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		// Validate like POST does; an empty domain can never name an entry.
		if req.Domain == "" {
			http.Error(w, "Domain is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveHostsEntry(req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodGet:
		hosts := s.shieldManager.GetAllHosts()
		hostsCount := s.shieldManager.GetHostsCount()

		// Convert to an array of objects, which is easier for the front end
		// to render.
		hostsList := make([]map[string]string, 0, len(hosts))
		for domain, ip := range hosts {
			hostsList = append(hostsList, map[string]string{
				"domain": domain,
				"ip":     ip,
			})
		}

		json.NewEncoder(w).Encode(map[string]interface{}{
			"hosts":      hostsList,
			"hostsCount": hostsCount,
		})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
|
||
|
||
// handleQuery serves GET /api/query?domain=<name>. It returns the shield
// manager's block details for that domain as JSON, with an added "timestamp"
// key holding the server time. A missing domain parameter yields 400.
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	domain := r.URL.Query().Get("domain")
	if domain == "" {
		http.Error(w, "Domain parameter is required", http.StatusBadRequest)
		return
	}

	blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
	// Writing into a nil map panics. CheckDomainBlockDetails's contract is not
	// visible here, so guard before adding the timestamp.
	if blockDetails == nil {
		blockDetails = map[string]interface{}{}
	}
	blockDetails["timestamp"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(blockDetails)
}
|
||
|
||
// handleStatus serves GET /api/status. It combines DNS counters from
// GetStats with the DNS server's start time. "uptime" is the time.Duration
// since that start; encoding/json writes a Duration as integer nanoseconds.
// "status" is always "running".
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	snapshot := s.dnsServer.GetStats()
	startedAt := s.dnsServer.GetStartTime()

	payload := make(map[string]interface{}, 12)
	payload["status"] = "running"
	payload["queries"] = snapshot.Queries
	payload["blocked"] = snapshot.Blocked
	payload["allowed"] = snapshot.Allowed
	payload["errors"] = snapshot.Errors
	payload["lastQuery"] = snapshot.LastQuery
	payload["avgResponseTime"] = snapshot.AvgResponseTime
	payload["activeIPs"] = len(snapshot.SourceIPs)
	payload["startTime"] = startedAt
	payload["uptime"] = time.Since(startedAt)
	payload["cpuUsage"] = snapshot.CpuUsage
	payload["timestamp"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(payload)
}
|
||
|
||
// saveConfigToFile serializes cfg as indented JSON and writes it to filePath
// with permission bits 0644. An existing file is truncated and overwritten
// (not written atomically). A marshalling or write error is returned
// unchanged.
func saveConfigToFile(cfg *config.Config, filePath string) error {
	encoded, err := json.MarshalIndent(cfg, "", " ")
	if err == nil {
		err = ioutil.WriteFile(filePath, encoded, 0644)
	}
	return err
}
|
||
|
||
// handleConfig serves /api/config.
//
// GET returns the shield section of the global configuration as
// {"shield": {"blockMethod", "customBlockIP", "blacklists", "updateInterval"}}.
//
// POST updates that section from a body of the same shape. Only some fields
// are applied:
//   - blockMethod, if non-empty; it must be NXDOMAIN, refused, emptyIP or
//     customIP.
//   - customBlockIP, if non-empty; it must be a dotted IPv4 address, and it is
//     required when blockMethod is customIP.
//   - blacklists, if present (non-null); each URL must start with http:// or
//     https://.
//   - updateInterval, if positive.
//
// The whole request is validated before anything is changed, so a rejected
// request leaves the running configuration untouched. On success the
// configuration is written to ./config.json and {"success": true, "message":
// ...} is returned. A failed write or rule reload is only logged, because the
// in-memory update has already happened.
func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodGet:
		resp := map[string]interface{}{
			"shield": map[string]interface{}{
				"blockMethod":    s.globalConfig.Shield.BlockMethod,
				"customBlockIP":  s.globalConfig.Shield.CustomBlockIP,
				"blacklists":     s.globalConfig.Shield.Blacklists,
				"updateInterval": s.globalConfig.Shield.UpdateInterval,
			},
		}
		json.NewEncoder(w).Encode(resp)

	case http.MethodPost:
		var req struct {
			Shield struct {
				BlockMethod    string                  `json:"blockMethod"`
				CustomBlockIP  string                  `json:"customBlockIP"`
				Blacklists     []config.BlacklistEntry `json:"blacklists"`
				UpdateInterval int                     `json:"updateInterval"`
			} `json:"shield"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "无效的请求体", http.StatusBadRequest)
			return
		}
		update := req.Shield

		// Phase 1: validate everything. Nothing below may return early once
		// the live configuration has been modified.
		if update.BlockMethod != "" {
			validMethods := map[string]bool{
				"NXDOMAIN": true,
				"refused":  true,
				"emptyIP":  true,
				"customIP": true,
			}
			if !validMethods[update.BlockMethod] {
				http.Error(w, "无效的屏蔽方法", http.StatusBadRequest)
				return
			}
			if update.BlockMethod == "customIP" && update.CustomBlockIP == "" {
				http.Error(w, "自定义IP不能为空", http.StatusBadRequest)
				return
			}
		}
		// Check any supplied IP, not only when it arrives together with
		// blockMethod=customIP. Otherwise an invalid address could be stored
		// and used whenever customIP is, or later becomes, the active method.
		if update.CustomBlockIP != "" && !isValidIP(update.CustomBlockIP) {
			http.Error(w, "无效的IP地址", http.StatusBadRequest)
			return
		}
		for i, bl := range update.Blacklists {
			if bl.URL == "" {
				http.Error(w, fmt.Sprintf("黑名单URL不能为空,索引: %d", i), http.StatusBadRequest)
				return
			}
			if !strings.HasPrefix(bl.URL, "http://") && !strings.HasPrefix(bl.URL, "https://") {
				http.Error(w, fmt.Sprintf("黑名单URL必须以http://或https://开头,索引: %d", i), http.StatusBadRequest)
				return
			}
		}

		// Phase 2: apply.
		if update.BlockMethod != "" {
			s.globalConfig.Shield.BlockMethod = update.BlockMethod
		}
		if update.CustomBlockIP != "" {
			s.globalConfig.Shield.CustomBlockIP = update.CustomBlockIP
		}
		if update.Blacklists != nil {
			s.globalConfig.Shield.Blacklists = update.Blacklists
			s.shieldManager.UpdateBlacklist(update.Blacklists)
			if err := s.shieldManager.LoadRules(); err != nil {
				logger.Error("重新加载规则失败", "error", err)
			}
		}
		if update.UpdateInterval > 0 {
			s.globalConfig.Shield.UpdateInterval = update.UpdateInterval
			// Restart the auto-update loop so the new interval takes effect.
			s.shieldManager.StopAutoUpdate()
			s.shieldManager.StartAutoUpdate()
		}

		if err := saveConfigToFile(s.globalConfig, "./config.json"); err != nil {
			// Not returned to the client: the in-memory update already succeeded.
			logger.Error("保存配置到文件失败", "error", err)
		}

		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": true,
			"message": "配置已更新",
		})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
|
||
|
||
// isValidIP reports whether ip looks like a dotted-quad IPv4 address:
// exactly four dot-separated fields, each made only of ASCII digits with a
// decimal value from 0 to 255. Leading zeros are accepted ("001.2.3.4").
// IPv6 addresses are rejected.
func isValidIP(ip string) bool {
	octets := strings.Split(ip, ".")
	if len(octets) != 4 {
		return false
	}

	notDigit := func(c rune) bool { return c < '0' || c > '9' }
	for _, octet := range octets {
		if strings.IndexFunc(octet, notDigit) >= 0 {
			return false
		}
		// An empty octet makes Sscanf fail; so does a value too large for int.
		var value int
		if _, err := fmt.Sscanf(octet, "%d", &value); err != nil || value < 0 || value > 255 {
			return false
		}
	}
	return true
}
|