Commit 32c7d3c0 authored by Tim Graham's avatar Tim Graham
Browse files

Fixed #15089 -- Allowed contrib.sites to lookup the current site based on request.get_host().

Thanks Claude Paroz, Riccardo Magliocchetti, and Damian Moore
for contributions to the patch.
parent 33396058
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -72,7 +72,7 @@ def shortcut(request, content_type_id, object_id):
        # Fall back to the current site (if possible).
        if object_domain is None:
            try:
                object_domain = Site.objects.get_current().domain
                object_domain = Site.objects.get_current(request).domain
            except Site.DoesNotExist:
                pass

+2 −2
Original line number Diff line number Diff line
from .models import Site
from .shortcuts import get_current_site


class CurrentSiteMiddleware(object):
@@ -7,4 +7,4 @@ class CurrentSiteMiddleware(object):
    """

    def process_request(self, request):
        request.site = Site.objects.get_current()
        request.site = get_current_site(request)
+34 −17
Original line number Diff line number Diff line
@@ -34,26 +34,39 @@ def _simple_domain_name_validator(value):

class SiteManager(models.Manager):

    def get_current(self):
    def _get_site_by_id(self, site_id):
        if site_id not in SITE_CACHE:
            site = self.get(pk=site_id)
            SITE_CACHE[site_id] = site
        return SITE_CACHE[site_id]

    def _get_site_by_request(self, request):
        host = request.get_host()
        if host not in SITE_CACHE:
            site = self.get(domain__iexact=host)
            SITE_CACHE[host] = site
        return SITE_CACHE[host]

    def get_current(self, request=None):
        """
        Returns the current ``Site`` based on the SITE_ID in the
        project's settings. The ``Site`` object is cached the first
        time it's retrieved from the database.
        Returns the current Site based on the SITE_ID in the project's settings.
        If SITE_ID isn't defined, it returns the site with domain matching
        request.get_host(). The ``Site`` object is cached the first time it's
        retrieved from the database.
        """
        from django.conf import settings
        try:
            sid = settings.SITE_ID
        except AttributeError:
        if getattr(settings, 'SITE_ID', ''):
            site_id = settings.SITE_ID
            return self._get_site_by_id(site_id)
        elif request:
            return self._get_site_by_request(request)

        raise ImproperlyConfigured(
            "You're using the Django \"sites framework\" without having "
            "set the SITE_ID setting. Create a site in your database and "
                "set the SITE_ID setting to fix this error.")
        try:
            current_site = SITE_CACHE[sid]
        except KeyError:
            current_site = self.get(pk=sid)
            SITE_CACHE[sid] = current_site
        return current_site
            "set the SITE_ID setting or pass a request to "
            "Site.objects.get_current() to fix this error."
        )

    def clear_cache(self):
        """Clears the ``Site`` object cache."""
@@ -103,5 +116,9 @@ def clear_site_cache(sender, **kwargs):
        del SITE_CACHE[instance.pk]
    except KeyError:
        pass
    try:
        del SITE_CACHE[Site.objects.get(pk=instance.pk).domain]
    except (KeyError, Site.DoesNotExist):
        pass
pre_save.connect(clear_site_cache, sender=Site)
pre_delete.connect(clear_site_cache, sender=Site)
+1 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ def get_current_site(request):
    # the Site models when django.contrib.sites isn't installed.
    if apps.is_installed('django.contrib.sites'):
        from .models import Site
        return Site.objects.get_current()
        return Site.objects.get_current(request)
    else:
        from .requests import RequestSite
        return RequestSite(request)
+39 −2
Original line number Diff line number Diff line
@@ -5,8 +5,9 @@ from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.http import HttpRequest
from django.test import TestCase, modify_settings, override_settings

from . import models
from .middleware import CurrentSiteMiddleware
from .models import Site
from .models import clear_site_cache, Site
from .requests import RequestSite
from .shortcuts import get_current_site

@@ -15,7 +16,12 @@ from .shortcuts import get_current_site
class SitesFrameworkTests(TestCase):

    def setUp(self):
        Site(id=settings.SITE_ID, domain="example.com", name="example.com").save()
        self.site = Site(
            id=settings.SITE_ID,
            domain="example.com",
            name="example.com",
        )
        self.site.save()

    def test_save_another(self):
        # Regression for #17415
@@ -71,6 +77,17 @@ class SitesFrameworkTests(TestCase):
            self.assertIsInstance(site, RequestSite)
            self.assertEqual(site.name, "example.com")

    @override_settings(SITE_ID='', ALLOWED_HOSTS=['example.com'])
    def test_get_current_site_no_site_id(self):
        request = HttpRequest()
        request.META = {
            "SERVER_NAME": "example.com",
            "SERVER_PORT": "80",
        }
        del settings.SITE_ID
        site = get_current_site(request)
        self.assertEqual(site.name, "example.com")

    def test_domain_name_with_whitespaces(self):
        # Regression for #17320
        # Domain names are not allowed contain whitespace characters
@@ -81,6 +98,26 @@ class SitesFrameworkTests(TestCase):
        site.domain = "test\ntest"
        self.assertRaises(ValidationError, site.full_clean)

    def test_clear_site_cache(self):
        request = HttpRequest()
        request.META = {
            "SERVER_NAME": "example.com",
            "SERVER_PORT": "80",
        }
        self.assertEqual(models.SITE_CACHE, {})
        get_current_site(request)
        expected_cache = {self.site.id: self.site}
        self.assertEqual(models.SITE_CACHE, expected_cache)

        with self.settings(SITE_ID=''):
            get_current_site(request)

        expected_cache.update({self.site.domain: self.site})
        self.assertEqual(models.SITE_CACHE, expected_cache)

        clear_site_cache(Site, instance=self.site)
        self.assertEqual(models.SITE_CACHE, {})


class MiddlewareTest(TestCase):

Loading