Commit e2a652fa authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Refactored datetime handling in the database cache backend.

Took advantage of the new database adapters and converters.
parent 2a05a823
Loading
Loading
Loading
Loading
+43 −43
Original line number Diff line number Diff line
@@ -4,8 +4,7 @@ from datetime import datetime

from django.conf import settings
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.db import DatabaseError, connections, router, transaction
from django.db.backends.utils import typecast_timestamp
from django.db import DatabaseError, connections, models, router, transaction
from django.utils import six, timezone
from django.utils.encoding import force_bytes

@@ -46,43 +45,40 @@ class BaseDatabaseCache(BaseCache):
class DatabaseCache(BaseDatabaseCache):

    # This class uses cursors provided by the database connection. This means
    # it reads expiration values as aware or naive datetimes depending on the
    # value of USE_TZ. They must be compared to aware or naive representations
    # of "now" respectively.

    # But it bypasses the ORM for write operations. As a consequence, aware
    # datetimes aren't made naive for databases that don't support time zones.
    # We work around this problem by always using naive datetimes when writing
    # expiration values, in UTC when USE_TZ = True and in local time otherwise.
    # it reads expiration values as aware or naive datetimes, depending on the
    # value of USE_TZ and whether the database supports time zones. The ORM's
    # conversion and adaptation infrastructure is then used to avoid comparing
    # aware and naive datetimes accidentally.

    def get(self, key, default=None, version=None):
        key = self.make_key(key, version=version)
        self.validate_key(key)
        db = router.db_for_read(self.cache_model_class)
        table = connections[db].ops.quote_name(self._table)
        connection = connections[db]
        table = connection.ops.quote_name(self._table)

        with connections[db].cursor() as cursor:
        with connection.cursor() as cursor:
            cursor.execute("SELECT cache_key, value, expires FROM %s "
                           "WHERE cache_key = %%s" % table, [key])
            row = cursor.fetchone()
        if row is None:
            return default
        now = timezone.now()

        expires = row[2]
        if connections[db].features.needs_datetime_string_cast and not isinstance(expires, datetime):
            # Note: typecasting is needed by some 3rd party database backends.
            # All core backends work without typecasting, so be careful about
            # changes here - test suite will NOT pick regressions here.
            expires = typecast_timestamp(str(expires))
        if settings.USE_TZ:
            expires = expires.replace(tzinfo=timezone.utc)
        if expires < now:
        expression = models.Expression(output_field=models.DateTimeField())
        for converter in (connection.ops.get_db_converters(expression) +
                          expression.get_db_converters(connection)):
            expires = converter(expires, expression, connection, {})

        if expires < timezone.now():
            db = router.db_for_write(self.cache_model_class)
            with connections[db].cursor() as cursor:
            connection = connections[db]
            with connection.cursor() as cursor:
                cursor.execute("DELETE FROM %s "
                               "WHERE cache_key = %%s" % table, [key])
            return default
        value = connections[db].ops.process_clob(row[1])

        value = connection.ops.process_clob(row[1])
        return pickle.loads(base64.b64decode(force_bytes(value)))

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
@@ -98,9 +94,10 @@ class DatabaseCache(BaseDatabaseCache):
    def _base_set(self, mode, key, value, timeout=DEFAULT_TIMEOUT):
        timeout = self.get_backend_timeout(timeout)
        db = router.db_for_write(self.cache_model_class)
        table = connections[db].ops.quote_name(self._table)
        connection = connections[db]
        table = connection.ops.quote_name(self._table)

        with connections[db].cursor() as cursor:
        with connection.cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM %s" % table)
            num = cursor.fetchone()[0]
            now = timezone.now()
@@ -129,14 +126,15 @@ class DatabaseCache(BaseDatabaseCache):
                    cursor.execute("SELECT cache_key, expires FROM %s "
                                   "WHERE cache_key = %%s" % table, [key])
                    result = cursor.fetchone()

                    if result:
                        current_expires = result[1]
                        if (connections[db].features.needs_datetime_string_cast and not
                                isinstance(current_expires, datetime)):
                            current_expires = typecast_timestamp(str(current_expires))
                        if settings.USE_TZ:
                            current_expires = current_expires.replace(tzinfo=timezone.utc)
                    exp = connections[db].ops.adapt_datetimefield_value(exp)
                        expression = models.Expression(output_field=models.DateTimeField())
                        for converter in (connection.ops.get_db_converters(expression) +
                                          expression.get_db_converters(connection)):
                            current_expires = converter(current_expires, expression, connection, {})

                    exp = connection.ops.adapt_datetimefield_value(exp)
                    if result and (mode == 'set' or (mode == 'add' and current_expires < now)):
                        cursor.execute("UPDATE %s SET value = %%s, expires = %%s "
                                       "WHERE cache_key = %%s" % table,
@@ -156,9 +154,10 @@ class DatabaseCache(BaseDatabaseCache):
        self.validate_key(key)

        db = router.db_for_write(self.cache_model_class)
        table = connections[db].ops.quote_name(self._table)
        connection = connections[db]
        table = connection.ops.quote_name(self._table)

        with connections[db].cursor() as cursor:
        with connection.cursor() as cursor:
            cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key])

    def has_key(self, key, version=None):
@@ -166,7 +165,8 @@ class DatabaseCache(BaseDatabaseCache):
        self.validate_key(key)

        db = router.db_for_read(self.cache_model_class)
        table = connections[db].ops.quote_name(self._table)
        connection = connections[db]
        table = connection.ops.quote_name(self._table)

        if settings.USE_TZ:
            now = datetime.utcnow()
@@ -174,27 +174,26 @@ class DatabaseCache(BaseDatabaseCache):
            now = datetime.now()
        now = now.replace(microsecond=0)

        with connections[db].cursor() as cursor:
        with connection.cursor() as cursor:
            cursor.execute("SELECT cache_key FROM %s "
                           "WHERE cache_key = %%s and expires > %%s" % table,
                           [key, connections[db].ops.adapt_datetimefield_value(now)])
                           [key, connection.ops.adapt_datetimefield_value(now)])
            return cursor.fetchone() is not None

    def _cull(self, db, cursor, now):
        if self._cull_frequency == 0:
            self.clear()
        else:
            # When USE_TZ is True, 'now' will be an aware datetime in UTC.
            now = now.replace(tzinfo=None)
            table = connections[db].ops.quote_name(self._table)
            connection = connections[db]
            table = connection.ops.quote_name(self._table)
            cursor.execute("DELETE FROM %s WHERE expires < %%s" % table,
                           [connections[db].ops.adapt_datetimefield_value(now)])
                           [connection.ops.adapt_datetimefield_value(now)])
            cursor.execute("SELECT COUNT(*) FROM %s" % table)
            num = cursor.fetchone()[0]
            if num > self._max_entries:
                cull_num = num // self._cull_frequency
                cursor.execute(
                    connections[db].ops.cache_key_culling_sql() % table,
                    connection.ops.cache_key_culling_sql() % table,
                    [cull_num])
                cursor.execute("DELETE FROM %s "
                               "WHERE cache_key < %%s" % table,
@@ -202,6 +201,7 @@ class DatabaseCache(BaseDatabaseCache):

    def clear(self):
        db = router.db_for_write(self.cache_model_class)
        table = connections[db].ops.quote_name(self._table)
        with connections[db].cursor() as cursor:
        connection = connections[db]
        table = connection.ops.quote_name(self._table)
        with connection.cursor() as cursor:
            cursor.execute('DELETE FROM %s' % table)