Commit 263b3d2b authored by Dmitry Dygalo's avatar Dmitry Dygalo Committed by Tim Graham
Browse files

Fixed #25666 -- Fixed the exact lookup of ArrayField.

parent b8f78823
Loading
Loading
Loading
Loading
+12 −3
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions
from django.db.models import Field, IntegerField, Transform
from django.db.models.lookups import Exact
from django.utils import six
from django.utils.translation import string_concat, ugettext_lazy as _

@@ -166,7 +167,7 @@ class ArrayField(Field):
class ArrayContains(lookups.DataContains):
    def as_sql(self, qn, connection):
        sql, params = super(ArrayContains, self).as_sql(qn, connection)
        sql += '::%s' % self.lhs.output_field.db_type(connection)
        sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
        return sql, params


@@ -174,7 +175,15 @@ class ArrayContains(lookups.DataContains):
class ArrayContainedBy(lookups.ContainedBy):
    def as_sql(self, qn, connection):
        sql, params = super(ArrayContainedBy, self).as_sql(qn, connection)
        sql += '::%s' % self.lhs.output_field.db_type(connection)
        sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
        return sql, params


@ArrayField.register_lookup
class ArrayExact(Exact):
    def as_sql(self, qn, connection):
        sql, params = super(ArrayExact, self).as_sql(qn, connection)
        sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
        return sql, params


@@ -182,7 +191,7 @@ class ArrayContainedBy(lookups.ContainedBy):
class ArrayOverlap(lookups.Overlap):
    def as_sql(self, qn, connection):
        sql, params = super(ArrayOverlap, self).as_sql(qn, connection)
        sql += '::%s' % self.lhs.output_field.db_type(connection)
        sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
        return sql, params


+2 −0
Original line number Diff line number Diff line
@@ -34,3 +34,5 @@ Bugfixes
* Fixed serialization of
  :class:`~django.contrib.postgres.fields.DateRangeField` and
  :class:`~django.contrib.postgres.fields.DateTimeRangeField` (:ticket:`24937`).

* Fixed the exact lookup of ``ArrayField`` (:ticket:`25666`).
+81 −0
Original line number Diff line number Diff line
@@ -122,6 +122,20 @@ class TestQuerying(PostgreSQLTestCase):
            self.objs[:1]
        )

    def test_exact_charfield(self):
        instance = CharArrayModel.objects.create(field=['text'])
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field=['text']),
            [instance]
        )

    def test_exact_nested(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]),
            [instance]
        )

    def test_isnull(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__isnull=True),
@@ -244,6 +258,73 @@ class TestQuerying(PostgreSQLTestCase):
        )


class TestDateTimeExactQuerying(PostgreSQLTestCase):

    def setUp(self):
        now = timezone.now()
        self.datetimes = [now]
        self.dates = [now.date()]
        self.times = [now.time()]
        self.objs = [
            DateTimeArrayModel.objects.create(
                datetimes=self.datetimes,
                dates=self.dates,
                times=self.times,
            )
        ]

    def test_exact_datetimes(self):
        self.assertSequenceEqual(
            DateTimeArrayModel.objects.filter(datetimes=self.datetimes),
            self.objs
        )

    def test_exact_dates(self):
        self.assertSequenceEqual(
            DateTimeArrayModel.objects.filter(dates=self.dates),
            self.objs
        )

    def test_exact_times(self):
        self.assertSequenceEqual(
            DateTimeArrayModel.objects.filter(times=self.times),
            self.objs
        )


class TestOtherTypesExactQuerying(PostgreSQLTestCase):

    def setUp(self):
        self.ips = ['192.168.0.1', '::1']
        self.uuids = [uuid.uuid4()]
        self.decimals = [decimal.Decimal(1.25), 1.75]
        self.objs = [
            OtherTypesArrayModel.objects.create(
                ips=self.ips,
                uuids=self.uuids,
                decimals=self.decimals,
            )
        ]

    def test_exact_ip_addresses(self):
        self.assertSequenceEqual(
            OtherTypesArrayModel.objects.filter(ips=self.ips),
            self.objs
        )

    def test_exact_uuids(self):
        self.assertSequenceEqual(
            OtherTypesArrayModel.objects.filter(uuids=self.uuids),
            self.objs
        )

    def test_exact_decimals(self):
        self.assertSequenceEqual(
            OtherTypesArrayModel.objects.filter(decimals=self.decimals),
            self.objs
        )


class TestChecks(PostgreSQLTestCase):

    def test_field_checks(self):