Commit 2495023a authored by Fernando Miranda's avatar Fernando Miranda Committed by Tim Graham
Browse files

Fixed #25143 -- Added ArrayField.from_db_value().

Thanks Karan Lyons for contributing to the patch.
parent f8d20da0
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -28,6 +28,10 @@ class ArrayField(Field):
        if self.size:
            self.default_validators = self.default_validators[:]
            self.default_validators.append(ArrayMaxLengthValidator(self.size))
        # For performance, only add a from_db_value() method if the base field
        # implements it.
        if hasattr(self.base_field, 'from_db_value'):
            self.from_db_value = self._from_db_value
        super(ArrayField, self).__init__(**kwargs)

    @property
@@ -100,6 +104,14 @@ class ArrayField(Field):
            value = [self.base_field.to_python(val) for val in vals]
        return value

    def _from_db_value(self, value, expression, connection, context):
        if value is None:
            return value
        return [
            self.base_field.from_db_value(item, expression, connection, context)
            for item in value
        ]

    def value_to_string(self, obj):
        values = []
        vals = self.value_from_object(obj)
+2 −0
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from django.db import migrations, models

from ..fields import *  # NOQA
from ..models import TagField


class Migration(migrations.Migration):
@@ -55,6 +56,7 @@ class Migration(migrations.Migration):
                ('ips', ArrayField(models.GenericIPAddressField(), size=None)),
                ('uuids', ArrayField(models.UUIDField(), size=None)),
                ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)),
                ('tags', ArrayField(TagField(), blank=True, null=True, size=None)),
            ],
            options={
                'required_db_vendor': 'postgresql',
+30 −0
Original line number Diff line number Diff line
@@ -6,6 +6,35 @@ from .fields import (
)


class Tag(object):
    def __init__(self, tag_id):
        self.tag_id = tag_id

    def __eq__(self, other):
        return isinstance(other, Tag) and self.tag_id == other.tag_id


class TagField(models.SmallIntegerField):

    def from_db_value(self, value, expression, connection, context):
        if value is None:
            return value
        return Tag(int(value))

    def to_python(self, value):
        if isinstance(value, Tag):
            return value
        if value is None:
            return value
        return Tag(int(value))

    def get_prep_value(self, value):
        return value.tag_id

    def get_db_prep_value(self, value, connection, prepared=False):
        return self.get_prep_value(value)


class PostgreSQLModel(models.Model):
    class Meta:
        abstract = True
@@ -38,6 +67,7 @@ class OtherTypesArrayModel(PostgreSQLModel):
    ips = ArrayField(models.GenericIPAddressField())
    uuids = ArrayField(models.UUIDField())
    decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2))
    tags = ArrayField(TagField(), blank=True, null=True)


class HStoreModel(PostgreSQLModel):
+21 −1
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ from . import PostgreSQLTestCase
from .models import (
    ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel,
    NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel,
    PostgreSQLModel,
    PostgreSQLModel, Tag,
)

try:
@@ -92,12 +92,24 @@ class TestSaveLoad(PostgreSQLTestCase):
            ips=['192.168.0.1', '::1'],
            uuids=[uuid.uuid4()],
            decimals=[decimal.Decimal(1.25), 1.75],
            tags=[Tag(1), Tag(2), Tag(3)],
        )
        instance.save()
        loaded = OtherTypesArrayModel.objects.get()
        self.assertEqual(instance.ips, loaded.ips)
        self.assertEqual(instance.uuids, loaded.uuids)
        self.assertEqual(instance.decimals, loaded.decimals)
        self.assertEqual(instance.tags, loaded.tags)

    def test_null_from_db_value_handling(self):
        instance = OtherTypesArrayModel.objects.create(
            ips=['192.168.0.1', '::1'],
            uuids=[uuid.uuid4()],
            decimals=[decimal.Decimal(1.25), 1.75],
            tags=None,
        )
        instance.refresh_from_db()
        self.assertIsNone(instance.tags)

    def test_model_set_on_base_field(self):
        instance = IntegerArrayModel()
@@ -306,11 +318,13 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase):
        self.ips = ['192.168.0.1', '::1']
        self.uuids = [uuid.uuid4()]
        self.decimals = [decimal.Decimal(1.25), 1.75]
        self.tags = [Tag(1), Tag(2), Tag(3)]
        self.objs = [
            OtherTypesArrayModel.objects.create(
                ips=self.ips,
                uuids=self.uuids,
                decimals=self.decimals,
                tags=self.tags,
            )
        ]

@@ -332,6 +346,12 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase):
            self.objs
        )

    def test_exact_tags(self):
        self.assertSequenceEqual(
            OtherTypesArrayModel.objects.filter(tags=self.tags),
            self.objs
        )


@isolate_apps('postgres_tests')
class TestChecks(PostgreSQLTestCase):