Loading django/http/request.py +33 −21 Original line number Diff line number Diff line Loading @@ -4,7 +4,6 @@ import copy import os import re import sys import warnings from io import BytesIO from pprint import pformat try: Loading Loading @@ -66,11 +65,14 @@ class HttpRequest(object): host = '%s:%s' % (host, server_port) allowed_hosts = ['*'] if settings.DEBUG else settings.ALLOWED_HOSTS if validate_host(host, allowed_hosts): domain, port = split_domain_port(host) if domain and validate_host(domain, allowed_hosts): return host else: raise SuspiciousOperation( "Invalid HTTP_HOST header (you may need to set ALLOWED_HOSTS): %s" % host) msg = "Invalid HTTP_HOST header: %r." % host if domain: msg += "You may need to add %r to ALLOWED_HOSTS." % domain raise SuspiciousOperation(msg) def get_full_path(self): # RFC 3986 requires query string arguments to be in the ASCII range. Loading Loading @@ -454,9 +456,30 @@ def bytes_to_text(s, encoding): return s def split_domain_port(host): """ Return a (domain, port) tuple from a given host. Returned domain is lower-cased. If the host is invalid, the domain will be empty. """ host = host.lower() if not host_validation_re.match(host): return '', '' if host[-1] == ']': # It's an IPv6 address without a port. return host, '' bits = host.rsplit(':', 1) if len(bits) == 2: return tuple(bits) return bits[0], '' def validate_host(host, allowed_hosts): """ Validate the given host header value for this site. Validate the given host for this site. Check that the host looks valid and matches a host or host pattern in the given list of ``allowed_hosts``. Any pattern beginning with a period Loading @@ -464,31 +487,20 @@ def validate_host(host, allowed_hosts): ``example.com`` and any subdomain), ``*`` matches anything, and anything else must match exactly. Note: This function assumes that the given host is lower-cased and has already had the port, if any, stripped off. Return ``True`` for a valid host, ``False`` otherwise. """ # All validation is case-insensitive host = host.lower() # Basic sanity check if not host_validation_re.match(host): return False # Validate only the domain part. if host[-1] == ']': # It's an IPv6 address without a port. domain = host else: domain = host.rsplit(':', 1)[0] for pattern in allowed_hosts: pattern = pattern.lower() match = ( pattern == '*' or pattern.startswith('.') and ( domain.endswith(pattern) or domain == pattern[1:] host.endswith(pattern) or host == pattern[1:] ) or pattern == domain pattern == host ) if match: return True Loading tests/requests/tests.py +54 −4 Original line number Diff line number Diff line Loading @@ -11,16 +11,16 @@ from django.core import signals from django.core.exceptions import SuspiciousOperation from django.core.handlers.wsgi import WSGIRequest, LimitedStream from django.http import HttpRequest, HttpResponse, parse_cookie, build_request_repr, UnreadablePostError from django.test import TransactionTestCase from django.test import SimpleTestCase, TransactionTestCase from django.test.client import FakePayload from django.test.utils import override_settings, str_prefix from django.utils import six from django.utils import unittest from django.utils.unittest import skipIf from django.utils.http import cookie_date, urlencode from django.utils.timezone import utc class RequestsTests(unittest.TestCase): class RequestsTests(SimpleTestCase): def test_httprequest(self): request = HttpRequest() self.assertEqual(list(request.GET.keys()), []) Loading Loading @@ -287,6 +287,56 @@ class RequestsTests(unittest.TestCase): self.assertEqual(request.get_host(), 'example.com') @override_settings(ALLOWED_HOSTS=[]) def test_get_host_suggestion_of_allowed_host(self): """get_host() makes helpful suggestions if a valid-looking host is not in ALLOWED_HOSTS.""" msg_invalid_host = "Invalid HTTP_HOST header: %r." msg_suggestion = msg_invalid_host + "You may need to add %r to ALLOWED_HOSTS." for host in [ # Valid-looking hosts 'example.com', '12.34.56.78', '[2001:19f0:feee::dead:beef:cafe]', 'xn--4ca9at.com', # Punnycode for öäü.com ]: request = HttpRequest() request.META = {'HTTP_HOST': host} self.assertRaisesMessage( SuspiciousOperation, msg_suggestion % (host, host), request.get_host ) for domain, port in [ # Valid-looking hosts with a port number ('example.com', 80), ('12.34.56.78', 443), ('[2001:19f0:feee::dead:beef:cafe]', 8080), ]: host = '%s:%s' % (domain, port) request = HttpRequest() request.META = {'HTTP_HOST': host} self.assertRaisesMessage( SuspiciousOperation, msg_suggestion % (host, domain), request.get_host ) for host in [ # Invalid hosts 'example.com@evil.tld', 'example.com:dr.frankenstein@evil.tld', 'example.com:dr.frankenstein@evil.tld:80', 'example.com:80/badpath', 'example.com: recovermypassword.com', ]: request = HttpRequest() request.META = {'HTTP_HOST': host} self.assertRaisesMessage( SuspiciousOperation, msg_invalid_host % host, request.get_host ) def test_near_expiration(self): "Cookie will expire when an near expiration time is provided" response = HttpResponse() Loading Loading @@ -587,7 +637,7 @@ class RequestsTests(unittest.TestCase): request.body @unittest.skipIf(connection.vendor == 'sqlite' @skipIf(connection.vendor == 'sqlite' and connection.settings_dict['NAME'] in ('', ':memory:'), "Cannot establish two connections to an in-memory SQLite database.") class DatabaseConnectionHandlingTests(TransactionTestCase): Loading Loading
django/http/request.py +33 −21 Original line number Diff line number Diff line Loading @@ -4,7 +4,6 @@ import copy import os import re import sys import warnings from io import BytesIO from pprint import pformat try: Loading Loading @@ -66,11 +65,14 @@ class HttpRequest(object): host = '%s:%s' % (host, server_port) allowed_hosts = ['*'] if settings.DEBUG else settings.ALLOWED_HOSTS if validate_host(host, allowed_hosts): domain, port = split_domain_port(host) if domain and validate_host(domain, allowed_hosts): return host else: raise SuspiciousOperation( "Invalid HTTP_HOST header (you may need to set ALLOWED_HOSTS): %s" % host) msg = "Invalid HTTP_HOST header: %r." % host if domain: msg += "You may need to add %r to ALLOWED_HOSTS." % domain raise SuspiciousOperation(msg) def get_full_path(self): # RFC 3986 requires query string arguments to be in the ASCII range. Loading Loading @@ -454,9 +456,30 @@ def bytes_to_text(s, encoding): return s def split_domain_port(host): """ Return a (domain, port) tuple from a given host. Returned domain is lower-cased. If the host is invalid, the domain will be empty. """ host = host.lower() if not host_validation_re.match(host): return '', '' if host[-1] == ']': # It's an IPv6 address without a port. return host, '' bits = host.rsplit(':', 1) if len(bits) == 2: return tuple(bits) return bits[0], '' def validate_host(host, allowed_hosts): """ Validate the given host header value for this site. Validate the given host for this site. Check that the host looks valid and matches a host or host pattern in the given list of ``allowed_hosts``. Any pattern beginning with a period Loading @@ -464,31 +487,20 @@ def validate_host(host, allowed_hosts): ``example.com`` and any subdomain), ``*`` matches anything, and anything else must match exactly. Note: This function assumes that the given host is lower-cased and has already had the port, if any, stripped off. Return ``True`` for a valid host, ``False`` otherwise. """ # All validation is case-insensitive host = host.lower() # Basic sanity check if not host_validation_re.match(host): return False # Validate only the domain part. if host[-1] == ']': # It's an IPv6 address without a port. domain = host else: domain = host.rsplit(':', 1)[0] for pattern in allowed_hosts: pattern = pattern.lower() match = ( pattern == '*' or pattern.startswith('.') and ( domain.endswith(pattern) or domain == pattern[1:] host.endswith(pattern) or host == pattern[1:] ) or pattern == domain pattern == host ) if match: return True Loading
tests/requests/tests.py +54 −4 Original line number Diff line number Diff line Loading @@ -11,16 +11,16 @@ from django.core import signals from django.core.exceptions import SuspiciousOperation from django.core.handlers.wsgi import WSGIRequest, LimitedStream from django.http import HttpRequest, HttpResponse, parse_cookie, build_request_repr, UnreadablePostError from django.test import TransactionTestCase from django.test import SimpleTestCase, TransactionTestCase from django.test.client import FakePayload from django.test.utils import override_settings, str_prefix from django.utils import six from django.utils import unittest from django.utils.unittest import skipIf from django.utils.http import cookie_date, urlencode from django.utils.timezone import utc class RequestsTests(unittest.TestCase): class RequestsTests(SimpleTestCase): def test_httprequest(self): request = HttpRequest() self.assertEqual(list(request.GET.keys()), []) Loading Loading @@ -287,6 +287,56 @@ class RequestsTests(unittest.TestCase): self.assertEqual(request.get_host(), 'example.com') @override_settings(ALLOWED_HOSTS=[]) def test_get_host_suggestion_of_allowed_host(self): """get_host() makes helpful suggestions if a valid-looking host is not in ALLOWED_HOSTS.""" msg_invalid_host = "Invalid HTTP_HOST header: %r." msg_suggestion = msg_invalid_host + "You may need to add %r to ALLOWED_HOSTS." for host in [ # Valid-looking hosts 'example.com', '12.34.56.78', '[2001:19f0:feee::dead:beef:cafe]', 'xn--4ca9at.com', # Punnycode for öäü.com ]: request = HttpRequest() request.META = {'HTTP_HOST': host} self.assertRaisesMessage( SuspiciousOperation, msg_suggestion % (host, host), request.get_host ) for domain, port in [ # Valid-looking hosts with a port number ('example.com', 80), ('12.34.56.78', 443), ('[2001:19f0:feee::dead:beef:cafe]', 8080), ]: host = '%s:%s' % (domain, port) request = HttpRequest() request.META = {'HTTP_HOST': host} self.assertRaisesMessage( SuspiciousOperation, msg_suggestion % (host, domain), request.get_host ) for host in [ # Invalid hosts 'example.com@evil.tld', 'example.com:dr.frankenstein@evil.tld', 'example.com:dr.frankenstein@evil.tld:80', 'example.com:80/badpath', 'example.com: recovermypassword.com', ]: request = HttpRequest() request.META = {'HTTP_HOST': host} self.assertRaisesMessage( SuspiciousOperation, msg_invalid_host % host, request.get_host ) def test_near_expiration(self): "Cookie will expire when an near expiration time is provided" response = HttpResponse() Loading Loading @@ -587,7 +637,7 @@ class RequestsTests(unittest.TestCase): request.body @unittest.skipIf(connection.vendor == 'sqlite' @skipIf(connection.vendor == 'sqlite' and connection.settings_dict['NAME'] in ('', ':memory:'), "Cannot establish two connections to an in-memory SQLite database.") class DatabaseConnectionHandlingTests(TransactionTestCase): Loading