Skip to content

Commit 3ebae1f

Browse files
ewdurbindi
andauthored
Revert "Revert "models/infra for ip_address tracking and bans"" (#12534)
* Revert "Revert "models/infra for ip_address tracking and bans (#12513)" (#12533)" This reverts commit 09e934c. * Set our own request._unauthenticated_userid * Tests Co-authored-by: Dustin Ingram <[email protected]>
1 parent 09e934c commit 3ebae1f

29 files changed

+843
-55
lines changed

tests/common/db/ip_addresses.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
from warehouse.ip_addresses.models import IpAddress
14+
15+
from .base import WarehouseFactory
16+
17+
18+
class IpAddressFactory(WarehouseFactory):
19+
class Meta:
20+
model = IpAddress

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def pyramid_request(pyramid_services, jinja, remote_addr):
156156
dummy_request.find_service = pyramid_services.find_service
157157
dummy_request.remote_addr = remote_addr
158158
dummy_request.authentication_method = pretend.stub()
159+
dummy_request._unauthenticated_userid = None
159160

160161
dummy_request.registry.registerUtility(jinja, IJinja2Environment, name=".jinja2")
161162

@@ -391,6 +392,7 @@ def query_recorder(app_config):
391392
def db_request(pyramid_request, db_session):
392393
pyramid_request.db = db_session
393394
pyramid_request.flags = admin.flags.Flags(pyramid_request)
395+
pyramid_request.banned = admin.bans.Bans(pyramid_request)
394396
return pyramid_request
395397

396398

tests/unit/accounts/test_core.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,12 @@ def test_without_identity(self):
321321
assert accounts._user(request) is None
322322

323323

324+
class TestUnauthenticatedUserid:
325+
def test_unauthenticated_userid(self):
326+
request = pretend.stub()
327+
assert accounts._unauthenticated_userid(request) is None
328+
329+
324330
def test_includeme(monkeypatch):
325331
authz_obj = pretend.stub()
326332
authz_cls = pretend.call_recorder(lambda *a, **kw: authz_obj)
@@ -358,7 +364,7 @@ def test_includeme(monkeypatch):
358364
register_service_factory=pretend.call_recorder(
359365
lambda factory, iface, name=None: None
360366
),
361-
add_request_method=pretend.call_recorder(lambda f, name, reify: None),
367+
add_request_method=pretend.call_recorder(lambda f, name, reify=False: None),
362368
set_security_policy=pretend.call_recorder(lambda p: None),
363369
maybe_dotted=pretend.call_recorder(lambda path: path),
364370
add_route_predicate=pretend.call_recorder(lambda name, cls: None),
@@ -389,7 +395,8 @@ def test_includeme(monkeypatch):
389395
pretend.call(RateLimit("3 per 6 hours"), IRateLimiter, name="email.verify"),
390396
]
391397
assert config.add_request_method.calls == [
392-
pretend.call(accounts._user, name="user", reify=True)
398+
pretend.call(accounts._user, name="user", reify=True),
399+
pretend.call(accounts._unauthenticated_userid, name="_unauthenticated_userid"),
393400
]
394401
assert config.set_security_policy.calls == [pretend.call(multi_policy_obj)]
395402
assert multi_policy_cls.calls == [

tests/unit/accounts/test_forms.py

Lines changed: 99 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,12 @@ def test_validate_username_with_user(self):
7171
assert user_service.find_userid.calls == [pretend.call("my_username")]
7272

7373
def test_validate_password_no_user(self):
74-
request = pretend.stub()
74+
request = pretend.stub(
75+
remote_addr="1.2.3.4",
76+
banned=pretend.stub(
77+
by_ip=lambda ip_address: False,
78+
),
79+
)
7580
user_service = pretend.stub(
7681
find_userid=pretend.call_recorder(lambda userid: None)
7782
)
@@ -92,7 +97,9 @@ def test_validate_password_no_user(self):
9297
]
9398

9499
def test_validate_password_disabled_for_compromised_pw(self, db_session):
95-
request = pretend.stub()
100+
request = pretend.stub(
101+
remote_addr="1.2.3.4", banned=pretend.stub(by_ip=lambda ip_address: False)
102+
)
96103
user_service = pretend.stub(
97104
find_userid=pretend.call_recorder(lambda userid: 1),
98105
is_disabled=pretend.call_recorder(
@@ -115,7 +122,12 @@ def test_validate_password_disabled_for_compromised_pw(self, db_session):
115122
assert user_service.is_disabled.calls == [pretend.call(1)]
116123

117124
def test_validate_password_ok(self):
118-
request = pretend.stub(remote_addr="1.2.3.4")
125+
request = pretend.stub(
126+
remote_addr="1.2.3.4",
127+
banned=pretend.stub(
128+
by_ip=lambda ip_address: False,
129+
),
130+
)
119131
user_service = pretend.stub(
120132
find_userid=pretend.call_recorder(lambda userid: 1),
121133
check_password=pretend.call_recorder(
@@ -150,7 +162,12 @@ def test_validate_password_ok(self):
150162
]
151163

152164
def test_validate_password_notok(self, db_session):
153-
request = pretend.stub(remote_addr="127.0.0.1")
165+
request = pretend.stub(
166+
remote_addr="1.2.3.4",
167+
banned=pretend.stub(
168+
by_ip=lambda ip_address: False,
169+
),
170+
)
154171
user_service = pretend.stub(
155172
find_userid=pretend.call_recorder(lambda userid: 1),
156173
check_password=pretend.call_recorder(
@@ -186,7 +203,12 @@ def test_validate_password_notok(self, db_session):
186203
]
187204

188205
def test_validate_password_too_many_failed(self):
189-
request = pretend.stub(remote_addr="1.2.3.4")
206+
request = pretend.stub(
207+
remote_addr="1.2.3.4",
208+
banned=pretend.stub(
209+
by_ip=lambda ip_address: False,
210+
),
211+
)
190212
user_service = pretend.stub(
191213
find_userid=pretend.call_recorder(lambda userid: 1),
192214
check_password=pretend.call_recorder(
@@ -218,7 +240,12 @@ def test_password_breached(self, monkeypatch):
218240
monkeypatch.setattr(forms, "send_password_compromised_email_hibp", send_email)
219241

220242
user = pretend.stub(id=1)
221-
request = pretend.stub(remote_addr="1.2.3.4")
243+
request = pretend.stub(
244+
remote_addr="1.2.3.4",
245+
banned=pretend.stub(
246+
by_ip=lambda ip_address: False,
247+
),
248+
)
222249
user_service = pretend.stub(
223250
find_userid=lambda _: 1,
224251
get_user=lambda _: user,
@@ -247,6 +274,72 @@ def test_password_breached(self, monkeypatch):
247274
]
248275
assert send_email.calls == [pretend.call(request, user)]
249276

277+
def test_validate_password_ok_ip_banned(self):
278+
request = pretend.stub(
279+
remote_addr="1.2.3.4",
280+
banned=pretend.stub(
281+
by_ip=lambda ip_address: True,
282+
),
283+
)
284+
user_service = pretend.stub(
285+
find_userid=pretend.call_recorder(lambda userid: 1),
286+
check_password=pretend.call_recorder(
287+
lambda userid, password, tags=None: True
288+
),
289+
is_disabled=pretend.call_recorder(lambda userid: (False, None)),
290+
)
291+
breach_service = pretend.stub(
292+
check_password=pretend.call_recorder(lambda pw, tags: False)
293+
)
294+
form = forms.LoginForm(
295+
data={"username": "my_username"},
296+
request=request,
297+
user_service=user_service,
298+
breach_service=breach_service,
299+
check_password_metrics_tags=["bar"],
300+
)
301+
field = pretend.stub(data="pw")
302+
303+
with pytest.raises(wtforms.validators.ValidationError):
304+
form.validate_password(field)
305+
306+
assert user_service.find_userid.calls == []
307+
assert user_service.is_disabled.calls == []
308+
assert user_service.check_password.calls == []
309+
assert breach_service.check_password.calls == []
310+
311+
def test_validate_password_notok_ip_banned(self, db_session):
312+
request = pretend.stub(
313+
remote_addr="1.2.3.4",
314+
banned=pretend.stub(
315+
by_ip=lambda ip_address: True,
316+
),
317+
)
318+
user_service = pretend.stub(
319+
find_userid=pretend.call_recorder(lambda userid: 1),
320+
check_password=pretend.call_recorder(
321+
lambda userid, password, tags=None: False
322+
),
323+
is_disabled=pretend.call_recorder(lambda userid: (False, None)),
324+
record_event=pretend.call_recorder(lambda *a, **kw: None),
325+
)
326+
breach_service = pretend.stub()
327+
form = forms.LoginForm(
328+
data={"username": "my_username"},
329+
request=request,
330+
user_service=user_service,
331+
breach_service=breach_service,
332+
)
333+
field = pretend.stub(data="pw")
334+
335+
with pytest.raises(wtforms.validators.ValidationError):
336+
form.validate_password(field)
337+
338+
assert user_service.find_userid.calls == []
339+
assert user_service.is_disabled.calls == []
340+
assert user_service.check_password.calls == []
341+
assert user_service.record_event.calls == []
342+
250343

251344
class TestRegistrationForm:
252345
def test_create(self):

tests/unit/accounts/test_models.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,22 @@ def test_query_by_email_when_not_primary(self, db_session):
103103
def test_recent_events(self, db_session):
104104
user = DBUserFactory.create()
105105
recent_event = DBUserEventFactory(source=user, tag="foo", ip_address="0.0.0.0")
106+
legacy_event = DBUserEventFactory(
107+
source=user,
108+
tag="wu",
109+
ip_address_string="0.0.0.0",
110+
time=datetime.datetime.now() - datetime.timedelta(days=1),
111+
)
106112
stale_event = DBUserEventFactory(
107113
source=user,
108114
tag="bar",
109115
ip_address="0.0.0.0",
110116
time=datetime.datetime.now() - datetime.timedelta(days=91),
111117
)
112118

113-
assert user.events.all() == [recent_event, stale_event]
114-
assert user.recent_events.all() == [recent_event]
119+
assert user.events.all() == [recent_event, legacy_event, stale_event]
120+
assert user.recent_events.all() == [recent_event, legacy_event]
121+
assert user.recent_events.all()[-1].ip_address == "0.0.0.0"
115122

116123
def test_regular_user_not_prohibited_password_reset(self, db_session):
117124
user = DBUserFactory.create()

0 commit comments

Comments
 (0)