|
| 1 | +package local_access |
| 2 | + |
| 3 | +import ( |
| 4 | + "net" |
| 5 | + "strings" |
| 6 | + |
| 7 | + "github.com/gofiber/fiber/v2" |
| 8 | + "github.com/limanmys/render-engine/pkg/logger" |
| 9 | +) |
| 10 | + |
| 11 | +// Config defines the configuration for local access middleware |
| 12 | +type Config struct { |
| 13 | + // AllowedNetworks defines the networks that are allowed to access the application |
| 14 | + // Default: []string{"127.0.0.1/32", "::1/128", "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"} |
| 15 | + AllowedNetworks []string |
| 16 | + |
| 17 | + // DenyMessage is the message returned when access is denied |
| 18 | + // Default: "Access denied: not from local network" |
| 19 | + DenyMessage string |
| 20 | + |
| 21 | + // StatusCode is the HTTP status code returned when access is denied |
| 22 | + // Default: 403 (Forbidden) |
| 23 | + StatusCode int |
| 24 | + |
| 25 | + // Next defines a function to skip this middleware when returned true. |
| 26 | + // Optional. Default: nil |
| 27 | + Next func(c *fiber.Ctx) bool |
| 28 | +} |
| 29 | + |
| 30 | +// ConfigDefault is the default config |
| 31 | +var ConfigDefault = Config{ |
| 32 | + AllowedNetworks: []string{ |
| 33 | + "127.0.0.1/32", // IPv4 localhost |
| 34 | + "::1/128", // IPv6 localhost |
| 35 | + "10.0.0.0/8", // Private network (Class A) |
| 36 | + "172.16.0.0/12", // Private network (Class B) |
| 37 | + "192.168.0.0/16", // Private network (Class C) |
| 38 | + }, |
| 39 | + DenyMessage: "Access denied: not from local network", |
| 40 | + StatusCode: fiber.StatusForbidden, |
| 41 | + Next: nil, |
| 42 | +} |
| 43 | + |
| 44 | +// Helper function to merge user config with default config |
| 45 | +func configDefault(config ...Config) Config { |
| 46 | + if len(config) < 1 { |
| 47 | + return ConfigDefault |
| 48 | + } |
| 49 | + |
| 50 | + cfg := config[0] |
| 51 | + |
| 52 | + if cfg.AllowedNetworks == nil { |
| 53 | + cfg.AllowedNetworks = ConfigDefault.AllowedNetworks |
| 54 | + } |
| 55 | + |
| 56 | + if cfg.DenyMessage == "" { |
| 57 | + cfg.DenyMessage = ConfigDefault.DenyMessage |
| 58 | + } |
| 59 | + |
| 60 | + if cfg.StatusCode == 0 { |
| 61 | + cfg.StatusCode = ConfigDefault.StatusCode |
| 62 | + } |
| 63 | + |
| 64 | + return cfg |
| 65 | +} |
| 66 | + |
| 67 | +// Creates a new local access middleware handler. |
| 68 | +// This middleware restricts access to certain routes based on the client's IP address. |
| 69 | +// It checks if the client IP is within the allowed networks (localhost and private networks by default). |
| 70 | +// Example usage: |
| 71 | +// app.Use(local_access.New()) |
| 72 | +// |
| 73 | +// Custom configuration: |
| 74 | +// |
| 75 | +// app.Use(local_access.New(local_access.Config{ |
| 76 | +// AllowedNetworks: []string{"192.168.1.0/24", "10.0.0.0/8"}, |
| 77 | +// DenyMessage: "Access denied from your network", |
| 78 | +// StatusCode: 403, |
| 79 | +// })) |
| 80 | +func New(config ...Config) fiber.Handler { |
| 81 | + cfg := configDefault(config...) |
| 82 | + |
| 83 | + // Parse allowed networks into CIDR blocks |
| 84 | + var allowedCIDRs []*net.IPNet |
| 85 | + for _, network := range cfg.AllowedNetworks { |
| 86 | + // Handle single IP addresses by adding appropriate CIDR suffix |
| 87 | + if !strings.Contains(network, "/") { |
| 88 | + if strings.Contains(network, ":") { |
| 89 | + // IPv6 address |
| 90 | + network += "/128" |
| 91 | + } else { |
| 92 | + // IPv4 address |
| 93 | + network += "/32" |
| 94 | + } |
| 95 | + } |
| 96 | + |
| 97 | + _, cidr, err := net.ParseCIDR(network) |
| 98 | + if err != nil { |
| 99 | + logger.Sugar().Errorf("Invalid network CIDR '%s': %v", network, err) |
| 100 | + continue |
| 101 | + } |
| 102 | + allowedCIDRs = append(allowedCIDRs, cidr) |
| 103 | + } |
| 104 | + |
| 105 | + return func(c *fiber.Ctx) error { |
| 106 | + // Skip middleware if Next returns true |
| 107 | + if cfg.Next != nil && cfg.Next(c) { |
| 108 | + return c.Next() |
| 109 | + } |
| 110 | + |
| 111 | + // Get client IP address |
| 112 | + clientIP := c.IP() |
| 113 | + if clientIP == "" { |
| 114 | + logger.Sugar().Warn("Unable to determine client IP address") |
| 115 | + return logger.FiberError(cfg.StatusCode, cfg.DenyMessage) |
| 116 | + } |
| 117 | + |
| 118 | + // Parse client IP |
| 119 | + ip := net.ParseIP(clientIP) |
| 120 | + if ip == nil { |
| 121 | + logger.Sugar().Warnf("Invalid client IP address: %s", clientIP) |
| 122 | + return logger.FiberError(cfg.StatusCode, cfg.DenyMessage) |
| 123 | + } |
| 124 | + |
| 125 | + // Check if client IP is in any of the allowed networks |
| 126 | + for _, cidr := range allowedCIDRs { |
| 127 | + if cidr.Contains(ip) { |
| 128 | + // IP is allowed, continue to next middleware |
| 129 | + return c.Next() |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + // Log the denied access attempt |
| 134 | + logger.Sugar().Infof("Access denied for IP: %s", clientIP) |
| 135 | + |
| 136 | + // IP is not in allowed networks, deny access |
| 137 | + return logger.FiberError(cfg.StatusCode, cfg.DenyMessage) |
| 138 | + } |
| 139 | +} |
0 commit comments