Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions internal/geo/geo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package geo
import (
"context"
"net/netip"
"os"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetCityAndASN(t *testing.T) {
Expand All @@ -20,6 +23,9 @@ func TestGetCityAndASN(t *testing.T) {

g.Init(ctx)

// Verify Init succeeded
assert.True(t, g.IsValid(), "GeoDatabase should be valid after successful Init")

ip := netip.MustParseAddr("2.125.160.216")
city, err := g.GetCity(ip)
if err != nil {
Expand Down Expand Up @@ -51,3 +57,136 @@ func TestGetCityAndASN(t *testing.T) {
continentName = city.Continent.Names.English
assert.Empty(t, continentName, "Expected empty continent name, got '%s'", continentName)
}

func TestInit_EmptyPaths(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

g := &GeoDatabase{
ASNPath: "",
CityPath: "",
}

g.Init(ctx)

// Should be invalid when both paths are empty
assert.False(t, g.IsValid(), "GeoDatabase should be invalid when both paths are empty")

// GetCity should return error
ip := netip.MustParseAddr("1.1.1.1")
_, err := g.GetCity(ip)
require.Error(t, err)
assert.Equal(t, ErrNotValidConfig, err)
}

func TestInit_MissingFiles(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

g := &GeoDatabase{
ASNPath: "/nonexistent/path/to/ASN.mmdb",
CityPath: "/nonexistent/path/to/City.mmdb",
}

g.Init(ctx)

// Should be invalid when files don't exist
assert.False(t, g.IsValid(), "GeoDatabase should be invalid when files don't exist")

// GetCity should return error
ip := netip.MustParseAddr("1.1.1.1")
_, err := g.GetCity(ip)
require.Error(t, err)
assert.Equal(t, ErrNotValidConfig, err)
}

func TestInit_OneValidPath(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Test with only ASN path
g := &GeoDatabase{
ASNPath: filepath.Join("test_data", "GeoLite2-ASN.mmdb"),
CityPath: "",
}

g.Init(ctx)
assert.True(t, g.IsValid(), "GeoDatabase should be valid with only ASN path")

ip := netip.MustParseAddr("1.0.0.1")
asn, err := g.GetASN(ip)
require.NoError(t, err)
assert.Equal(t, uint(15169), asn.AutonomousSystemNumber)

// City should return empty record, not error
city, err := g.GetCity(ip)
require.NoError(t, err)
assert.NotNil(t, city)
}

func TestInit_InvalidPathIsDirectory(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Use test_data directory as an invalid path (it's a directory, not a file)
g := &GeoDatabase{
ASNPath: "test_data",
CityPath: filepath.Join("test_data", "GeoLite2-City.mmdb"),
}

g.Init(ctx)

// Should be invalid when path is a directory
assert.False(t, g.IsValid(), "GeoDatabase should be invalid when path is a directory")
}

func TestInit_WatchFilesGoroutine(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

g := &GeoDatabase{
ASNPath: filepath.Join("test_data", "GeoLite2-ASN.mmdb"),
CityPath: filepath.Join("test_data", "GeoLite2-City.mmdb"),
}

g.Init(ctx)
assert.True(t, g.IsValid(), "GeoDatabase should be valid after successful Init")

// Give WatchFiles goroutine a moment to initialize
time.Sleep(100 * time.Millisecond)

// Verify lastModTime map is initialized
g.RLock()
assert.NotNil(t, g.lastModTime, "lastModTime map should be initialized")
g.RUnlock()

// Cancel context to stop WatchFiles goroutine
cancel()

// Give it a moment to clean up
time.Sleep(100 * time.Millisecond)

// Database should still be valid after context cancellation
assert.True(t, g.IsValid(), "GeoDatabase should remain valid after context cancellation")
}

func TestInit_EmptyFile(t *testing.T) {
// Create a temporary empty file
tmpFile, err := os.CreateTemp(t.TempDir(), "empty-*.mmdb")
require.NoError(t, err)
tmpPath := tmpFile.Name()
tmpFile.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

g := &GeoDatabase{
ASNPath: tmpPath,
CityPath: "",
}

g.Init(ctx)

// Should be invalid when file is empty
assert.False(t, g.IsValid(), "GeoDatabase should be invalid when file is empty")
}
133 changes: 95 additions & 38 deletions internal/geo/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ func (g *GeoDatabase) Init(ctx context.Context) {
return
}

// Validate paths exist before attempting to load
if err := g.validatePaths(); err != nil {
log.Errorf("geo database path validation failed: %s", err)
g.loadFailed = true
return
}

if err := g.open(); err != nil {
log.Errorf("failed to open databases: %s", err)
g.loadFailed = true
Expand All @@ -50,25 +57,65 @@ func (g *GeoDatabase) reload() error {
return g.open()
}

func (g *GeoDatabase) open() error {
if !g.IsValid() {
return ErrNotValidConfig
// validatePaths checks if the configured database paths exist and are readable
func (g *GeoDatabase) validatePaths() error {
if g.ASNPath != "" {
if err := g.validatePath(g.ASNPath, "ASN"); err != nil {
return err
}
}

if g.CityPath != "" {
if err := g.validatePath(g.CityPath, "City"); err != nil {
return err
}
}

return nil
}

// validatePath checks if a single database path exists and is readable
func (g *GeoDatabase) validatePath(path, dbType string) error {
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("%s database file does not exist: %s", dbType, path)
}
return fmt.Errorf("failed to stat %s database file %s: %w", dbType, path, err)
}

if info.IsDir() {
return fmt.Errorf("%s database path is a directory, not a file: %s", dbType, path)
}

if info.Size() == 0 {
return fmt.Errorf("%s database file is empty: %s", dbType, path)
}

return nil
}

func (g *GeoDatabase) open() error {
g.Lock()
defer g.Unlock()

var err error
if g.asnReader == nil && g.ASNPath != "" {
g.asnReader, err = geoip2.Open(g.ASNPath)
if err != nil {
return err
return fmt.Errorf("failed to open ASN database: %w", err)
}
}

if g.cityReader == nil && g.CityPath != "" {
g.cityReader, err = geoip2.Open(g.CityPath)
if err != nil {
return err
// Clean up ASN reader if it was opened successfully
if g.asnReader != nil {
g.asnReader.Close()
g.asnReader = nil
}
return fmt.Errorf("failed to open City database: %w", err)
}
}

Expand Down Expand Up @@ -158,51 +205,29 @@ func GetIsoCodeFromRecord(record *geoip2.City) string {
func (g *GeoDatabase) WatchFiles(ctx context.Context) {

ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
shouldUpdate := false
if asnLastModTime, ok := g.lastModTime[g.ASNPath]; ok {
info, err := os.Stat(g.ASNPath)
if err != nil {
log.Warnf("failed to stat ASN database: %s", err)
continue
}
if info.ModTime().After(asnLastModTime) {
log.Infof("ASN database has been updated, reloading")

// Check ASN database
if g.ASNPath != "" {
if updated := g.checkAndUpdateModTime(g.ASNPath, "ASN"); updated {
shouldUpdate = true
g.lastModTime[g.ASNPath] = info.ModTime()
}
} else {
info, err := os.Stat(g.ASNPath)
if err != nil {
log.Warnf("failed to stat ASN database: %s", err)
continue
}
g.lastModTime[g.ASNPath] = info.ModTime()
}
if cityLastModTime, ok := g.lastModTime[g.CityPath]; ok {
info, err := os.Stat(g.CityPath)
if err != nil {
log.Warnf("failed to stat city database: %s", err)
continue
}
if info.ModTime().After(cityLastModTime) {
log.Infof("City database has been updated, reloading")

// Check City database
if g.CityPath != "" {
if updated := g.checkAndUpdateModTime(g.CityPath, "City"); updated {
shouldUpdate = true
g.lastModTime[g.CityPath] = info.ModTime()
}
} else {
info, err := os.Stat(g.CityPath)
if err != nil {
log.Warnf("failed to stat city database: %s", err)
continue
}
g.lastModTime[g.CityPath] = info.ModTime()
}

if shouldUpdate {
if err := g.reload(); err != nil {
log.Warnf("failed to reload databases: %s", err)
Expand All @@ -211,3 +236,35 @@ func (g *GeoDatabase) WatchFiles(ctx context.Context) {
}
}
}

// checkAndUpdateModTime checks if a database file has been modified and updates the lastModTime
// Returns true if the file was updated (needs reload), false otherwise
func (g *GeoDatabase) checkAndUpdateModTime(path, dbType string) bool {
info, err := os.Stat(path)
if err != nil {
log.Warnf("failed to stat %s database: %s", dbType, err)
return false
}

g.RLock()
lastModTime, exists := g.lastModTime[path]
g.RUnlock()

if !exists {
// First time checking this file, just record the mod time
g.Lock()
g.lastModTime[path] = info.ModTime()
g.Unlock()
return false
}

if info.ModTime().After(lastModTime) {
log.Infof("%s database has been updated, reloading", dbType)
g.Lock()
g.lastModTime[path] = info.ModTime()
g.Unlock()
return true
}

return false
}
20 changes: 13 additions & 7 deletions pkg/spoa/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,9 @@ func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (remediatio
return remediation.Allow, "" // Safe default
}

// If no IP-specific remediation, check country-based
if r < remediation.Unknown && s.geoDatabase.IsValid() {
// Always try to get and set ISO code if geo database is available
// This allows upstream services to use the ISO code regardless of remediation status
if s.geoDatabase.IsValid() {
record, err := s.geoDatabase.GetCity(ip)
if err != nil && !errors.Is(err, geo.ErrNotValidConfig) {
s.logger.WithFields(log.Fields{
Expand All @@ -640,12 +641,17 @@ func (s *Spoa) getIPRemediation(req *request.Request, ip netip.Addr) (remediatio
} else if record != nil {
iso := geo.GetIsoCodeFromRecord(record)
if iso != "" {
cnR, cnOrigin := s.dataset.CheckCN(iso)
if cnR > remediation.Unknown {
r = cnR
origin = cnOrigin
}
// Always set the ISO code variable when available
req.Actions.SetVar(action.ScopeTransaction, "isocode", iso)

// If no IP-specific remediation, check country-based remediation
if r < remediation.Unknown {
cnR, cnOrigin := s.dataset.CheckCN(iso)
if cnR > remediation.Unknown {
r = cnR
origin = cnOrigin
}
}
}
}
}
Expand Down
Loading