4
4
5
5
from django .contrib .postgres .fields import ArrayField
6
6
from django .db import models
7
+ from django .utils .timezone import now
7
8
from django .utils .translation import gettext as _
8
9
from django_countries .fields import CountryField
10
+ from geopy import distance
9
11
from rest_framework .serializers import BaseSerializer
10
12
13
+ from authentik .events .context_processors .geoip import GeoIPDict
14
+ from authentik .events .models import Event , EventAction
11
15
from authentik .policies .exceptions import PolicyException
12
16
from authentik .policies .geoip .exceptions import GeoIPNotFoundException
13
17
from authentik .policies .models import Policy
14
18
from authentik .policies .types import PolicyRequest , PolicyResult
15
19
20
+ MAX_DISTANCE_HOUR_KM = 1000
21
+
16
22
17
23
class GeoIPPolicy (Policy ):
18
24
"""Ensure the user satisfies requirements of geography or network topology, based on IP
@@ -21,6 +27,15 @@ class GeoIPPolicy(Policy):
21
27
asns = ArrayField (models .IntegerField (), blank = True , default = list )
22
28
countries = CountryField (multiple = True , blank = True )
23
29
30
+ distance_tolerance_km = models .PositiveIntegerField (default = 50 )
31
+
32
+ check_history_distance = models .BooleanField (default = False )
33
+ history_max_distance_km = models .PositiveBigIntegerField (default = 100 )
34
+ history_login_count = models .PositiveIntegerField (default = 5 )
35
+
36
+ check_impossible_travel = models .BooleanField (default = False )
37
+ impossible_tolerance_km = models .PositiveIntegerField (default = 100 )
38
+
24
39
@property
25
40
def serializer (self ) -> type [BaseSerializer ]:
26
41
from authentik .policies .geoip .api import GeoIPPolicySerializer
@@ -37,21 +52,27 @@ def passes(self, request: PolicyRequest) -> PolicyResult:
37
52
- the client IP is advertised by an autonomous system with ASN in the `asns`
38
53
- the client IP is geolocated in a country of `countries`
39
54
"""
40
- results : list [PolicyResult ] = []
55
+ static_results : list [PolicyResult ] = []
56
+ dynamic_results : list [PolicyResult ] = []
41
57
42
58
if self .asns :
43
- results .append (self .passes_asn (request ))
59
+ static_results .append (self .passes_asn (request ))
44
60
if self .countries :
45
- results .append (self .passes_country (request ))
61
+ static_results .append (self .passes_country (request ))
46
62
47
- if not results :
63
+ if self .check_history_distance or self .check_impossible_travel :
64
+ dynamic_results .append (self .passes_distance (request ))
65
+
66
+ if not static_results and not dynamic_results :
48
67
return PolicyResult (True )
49
68
50
- passing = any (r .passing for r in results )
51
- messages = chain (* [r .messages for r in results ])
69
+ passing = any (r .passing for r in static_results ) and all (r .passing for r in dynamic_results )
70
+ messages = chain (
71
+ * [r .messages for r in static_results ], * [r .messages for r in dynamic_results ]
72
+ )
52
73
53
74
result = PolicyResult (passing , * messages )
54
- result .source_results = results
75
+ result .source_results = list ( chain ( static_results , dynamic_results ))
55
76
56
77
return result
57
78
@@ -73,7 +94,7 @@ def passes_asn(self, request: PolicyRequest) -> PolicyResult:
73
94
74
95
def passes_country (self , request : PolicyRequest ) -> PolicyResult :
75
96
# This is not a single get chain because `request.context` can contain `{ "geoip": None }`.
76
- geoip_data = request .context .get ("geoip" )
97
+ geoip_data : GeoIPDict | None = request .context .get ("geoip" )
77
98
country = geoip_data .get ("country" ) if geoip_data else None
78
99
79
100
if not country :
@@ -87,6 +108,42 @@ def passes_country(self, request: PolicyRequest) -> PolicyResult:
87
108
88
109
return PolicyResult (True )
89
110
111
+ def passes_distance (self , request : PolicyRequest ) -> PolicyResult :
112
+ """Check if current policy execution is out of distance range compared
113
+ to previous authentication requests"""
114
+ # Get previous login event and GeoIP data
115
+ previous_logins = Event .objects .filter (
116
+ action = EventAction .LOGIN , user__pk = request .user .pk , context__geo__isnull = False
117
+ ).order_by ("-created" )[: self .history_login_count ]
118
+ _now = now ()
119
+ geoip_data : GeoIPDict | None = request .context .get ("geoip" )
120
+ if not geoip_data :
121
+ return PolicyResult (False )
122
+ for previous_login in previous_logins :
123
+ previous_login_geoip : GeoIPDict = previous_login .context ["geo" ]
124
+
125
+ # Figure out distance
126
+ dist = distance .geodesic (
127
+ (previous_login_geoip ["lat" ], previous_login_geoip ["long" ]),
128
+ (geoip_data ["lat" ], geoip_data ["long" ]),
129
+ )
130
+ if self .check_history_distance and dist .km >= (
131
+ self .history_max_distance_km - self .distance_tolerance_km
132
+ ):
133
+ return PolicyResult (
134
+ False , _ ("Distance from previous authentication is larger than threshold." )
135
+ )
136
+ # Check if distance between `previous_login` and now is more
137
+ # than max distance per hour times the amount of hours since the previous login
138
+ # (round down to the lowest closest time of hours)
139
+ # clamped to be at least 1 hour
140
+ rel_time_hours = max (int ((_now - previous_login .created ).total_seconds () / 3600 ), 1 )
141
+ if self .check_impossible_travel and dist .km >= (
142
+ (MAX_DISTANCE_HOUR_KM * rel_time_hours ) - self .distance_tolerance_km
143
+ ):
144
+ return PolicyResult (False , _ ("Distance is further than possible." ))
145
+ return PolicyResult (True )
146
+
90
147
class Meta (Policy .PolicyMeta ):
91
148
verbose_name = _ ("GeoIP Policy" )
92
149
verbose_name_plural = _ ("GeoIP Policies" )
0 commit comments