新建DNS服务器
This commit is contained in:
444
http/server.go
Normal file
444
http/server.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"dns-server/config"
|
||||
"dns-server/dns"
|
||||
"dns-server/logger"
|
||||
"dns-server/shield"
|
||||
)
|
||||
|
||||
// Server HTTP控制台服务器
|
||||
type Server struct {
|
||||
globalConfig *config.Config
|
||||
config *config.HTTPConfig
|
||||
dnsServer *dns.Server
|
||||
shieldManager *shield.ShieldManager
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
// NewServer 创建HTTP服务器实例
|
||||
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager) *Server {
|
||||
return &Server{
|
||||
globalConfig: globalConfig,
|
||||
config: &globalConfig.HTTP,
|
||||
dnsServer: dnsServer,
|
||||
shieldManager: shieldManager,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动HTTP服务器
|
||||
func (s *Server) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// API路由
|
||||
if s.config.EnableAPI {
|
||||
mux.HandleFunc("/api/stats", s.handleStats)
|
||||
mux.HandleFunc("/api/shield", s.handleShield)
|
||||
mux.HandleFunc("/api/shield/hosts", s.handleShieldHosts)
|
||||
mux.HandleFunc("/api/query", s.handleQuery)
|
||||
mux.HandleFunc("/api/status", s.handleStatus)
|
||||
mux.HandleFunc("/api/config", s.handleConfig)
|
||||
// 添加统计相关接口
|
||||
mux.HandleFunc("/api/top-blocked", s.handleTopBlockedDomains)
|
||||
mux.HandleFunc("/api/recent-blocked", s.handleRecentBlockedDomains)
|
||||
mux.HandleFunc("/api/hourly-stats", s.handleHourlyStats)
|
||||
}
|
||||
|
||||
// 静态文件服务(可后续添加前端界面)
|
||||
mux.Handle("/", http.FileServer(http.Dir("./static")))
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", s.config.Host, s.config.Port),
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
logger.Info(fmt.Sprintf("HTTP控制台服务器启动,监听地址: %s:%d", s.config.Host, s.config.Port))
|
||||
return s.server.ListenAndServe()
|
||||
}
|
||||
|
||||
// Stop 停止HTTP服务器
|
||||
func (s *Server) Stop() {
|
||||
if s.server != nil {
|
||||
s.server.Close()
|
||||
}
|
||||
logger.Info("HTTP控制台服务器已停止")
|
||||
}
|
||||
|
||||
// handleStats 处理统计信息请求
|
||||
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
dnsStats := s.dnsServer.GetStats()
|
||||
shieldStats := s.shieldManager.GetStats()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"dns": dnsStats,
|
||||
"shield": shieldStats,
|
||||
"time": time.Now(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(stats)
|
||||
}
|
||||
|
||||
// handleTopBlockedDomains 处理TOP屏蔽域名请求
|
||||
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
domains := s.dnsServer.GetTopBlockedDomains(10)
|
||||
|
||||
// 转换为前端需要的格式
|
||||
result := make([]map[string]interface{}, len(domains))
|
||||
for i, domain := range domains {
|
||||
result[i] = map[string]interface{}{
|
||||
"domain": domain.Domain,
|
||||
"count": domain.Count,
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleRecentBlockedDomains 处理最近屏蔽域名请求
|
||||
func (s *Server) handleRecentBlockedDomains(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
domains := s.dnsServer.GetRecentBlockedDomains(10)
|
||||
|
||||
// 转换为前端需要的格式
|
||||
result := make([]map[string]interface{}, len(domains))
|
||||
for i, domain := range domains {
|
||||
result[i] = map[string]interface{}{
|
||||
"domain": domain.Domain,
|
||||
"time": domain.LastSeen.Format("15:04:05"),
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleHourlyStats 处理24小时统计请求
|
||||
func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
hourlyStats := s.dnsServer.GetHourlyStats()
|
||||
|
||||
// 生成最近24小时的数据
|
||||
now := time.Now()
|
||||
labels := make([]string, 24)
|
||||
data := make([]int64, 24)
|
||||
|
||||
for i := 23; i >= 0; i-- {
|
||||
hour := now.Add(time.Duration(-i) * time.Hour)
|
||||
hourKey := hour.Format("2006-01-02-15")
|
||||
labels[23-i] = hour.Format("15:00")
|
||||
data[23-i] = hourlyStats[hourKey]
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"labels": labels,
|
||||
"data": data,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleShield 处理屏蔽规则管理请求
|
||||
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// 处理hosts管理子路由
|
||||
if strings.HasPrefix(r.URL.Path, "/api/shield/hosts") {
|
||||
s.handleShieldHosts(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// 获取完整规则列表
|
||||
rules := s.shieldManager.GetRules()
|
||||
json.NewEncoder(w).Encode(rules)
|
||||
|
||||
case http.MethodPost:
|
||||
// 添加屏蔽规则
|
||||
var req struct {
|
||||
Rule string `json:"rule"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.shieldManager.AddRule(req.Rule); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||||
|
||||
case http.MethodDelete:
|
||||
// 删除屏蔽规则
|
||||
var req struct {
|
||||
Rule string `json:"rule"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.shieldManager.RemoveRule(req.Rule); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||||
|
||||
case http.MethodPut:
|
||||
// 重新加载规则
|
||||
if err := s.shieldManager.LoadRules(); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "规则重新加载成功"})
|
||||
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// handleShieldHosts 处理hosts管理请求
|
||||
func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
// 添加hosts条目
|
||||
var req struct {
|
||||
IP string `json:"ip"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.IP == "" || req.Domain == "" {
|
||||
http.Error(w, "IP and Domain are required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.shieldManager.AddHostsEntry(req.IP, req.Domain); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||||
|
||||
case http.MethodDelete:
|
||||
// 删除hosts条目
|
||||
var req struct {
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.shieldManager.RemoveHostsEntry(req.Domain); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||||
|
||||
case http.MethodGet:
|
||||
// 获取hosts条目列表
|
||||
// 注意:这需要在shieldManager中添加一个获取所有hosts条目的方法
|
||||
// 暂时返回统计信息
|
||||
stats := s.shieldManager.GetStats()
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"hostsCount": stats["hostsRules"],
|
||||
"message": "获取hosts列表功能待实现",
|
||||
})
|
||||
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// handleQuery 处理DNS查询请求
|
||||
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
domain := r.URL.Query().Get("domain")
|
||||
if domain == "" {
|
||||
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查域名是否被屏蔽
|
||||
blocked := s.shieldManager.IsBlocked(domain)
|
||||
|
||||
// 检查hosts文件是否有匹配
|
||||
hostsIP, hasHosts := s.shieldManager.GetHostsIP(domain)
|
||||
|
||||
result := map[string]interface{}{
|
||||
"domain": domain,
|
||||
"blocked": blocked,
|
||||
"hasHosts": hasHosts,
|
||||
"hostsIP": hostsIP,
|
||||
"timestamp": time.Now(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleStatus 处理系统状态请求
|
||||
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
stats := s.dnsServer.GetStats()
|
||||
status := map[string]interface{}{
|
||||
"status": "running",
|
||||
"queries": stats.Queries,
|
||||
"lastQuery": stats.LastQuery,
|
||||
"uptime": time.Since(stats.LastQuery),
|
||||
"timestamp": time.Now(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
|
||||
// handleConfig 处理配置请求
|
||||
func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// 返回当前配置(只返回必要的部分)
|
||||
config := map[string]interface{}{
|
||||
"shield": map[string]string{
|
||||
"blockMethod": s.globalConfig.Shield.BlockMethod,
|
||||
"customBlockIP": s.globalConfig.Shield.CustomBlockIP,
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(config)
|
||||
|
||||
case http.MethodPost:
|
||||
// 更新配置
|
||||
var req struct {
|
||||
Shield struct {
|
||||
BlockMethod string `json:"blockMethod"`
|
||||
CustomBlockIP string `json:"customBlockIP"`
|
||||
} `json:"shield"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "无效的请求体", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新屏蔽配置
|
||||
if req.Shield.BlockMethod != "" {
|
||||
// 验证屏蔽方法是否有效
|
||||
validMethods := map[string]bool{
|
||||
"NXDOMAIN": true,
|
||||
"refused": true,
|
||||
"emptyIP": true,
|
||||
"customIP": true,
|
||||
}
|
||||
|
||||
if !validMethods[req.Shield.BlockMethod] {
|
||||
http.Error(w, "无效的屏蔽方法", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.globalConfig.Shield.BlockMethod = req.Shield.BlockMethod
|
||||
|
||||
// 如果选择了customIP,验证IP地址
|
||||
if req.Shield.BlockMethod == "customIP" {
|
||||
if req.Shield.CustomBlockIP == "" {
|
||||
http.Error(w, "自定义IP不能为空", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// 简单的IP地址验证
|
||||
if !isValidIP(req.Shield.CustomBlockIP) {
|
||||
http.Error(w, "无效的IP地址", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.Shield.CustomBlockIP != "" {
|
||||
s.globalConfig.Shield.CustomBlockIP = req.Shield.CustomBlockIP
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "配置已更新",
|
||||
})
|
||||
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// isValidIP 简单验证IP地址格式
|
||||
func isValidIP(ip string) bool {
|
||||
// 简单的IPv4地址验证
|
||||
parts := strings.Split(ip, ".")
|
||||
if len(parts) != 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, part := range parts {
|
||||
// 检查是否为数字
|
||||
for _, char := range part {
|
||||
if char < '0' || char > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
// 检查数字范围
|
||||
var num int
|
||||
if _, err := fmt.Sscanf(part, "%d", &num); err != nil || num < 0 || num > 255 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user