Loading django/contrib/postgres/fields/array.py +12 −0 Original line number Diff line number Diff line Loading @@ -28,6 +28,10 @@ class ArrayField(Field): if self.size: self.default_validators = self.default_validators[:] self.default_validators.append(ArrayMaxLengthValidator(self.size)) # For performance, only add a from_db_value() method if the base field # implements it. if hasattr(self.base_field, 'from_db_value'): self.from_db_value = self._from_db_value super(ArrayField, self).__init__(**kwargs) @property Loading Loading @@ -100,6 +104,14 @@ class ArrayField(Field): value = [self.base_field.to_python(val) for val in vals] return value def _from_db_value(self, value, expression, connection, context): if value is None: return value return [ self.base_field.from_db_value(item, expression, connection, context) for item in value ] def value_to_string(self, obj): values = [] vals = self.value_from_object(obj) Loading tests/postgres_tests/migrations/0002_create_test_models.py +2 −0 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ from __future__ import unicode_literals from django.db import migrations, models from ..fields import * # NOQA from ..models import TagField class Migration(migrations.Migration): Loading Loading @@ -55,6 +56,7 @@ class Migration(migrations.Migration): ('ips', ArrayField(models.GenericIPAddressField(), size=None)), ('uuids', ArrayField(models.UUIDField(), size=None)), ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), ('tags', ArrayField(TagField(), blank=True, null=True, size=None)), ], options={ 'required_db_vendor': 'postgresql', Loading tests/postgres_tests/models.py +30 −0 Original line number Diff line number Diff line Loading @@ -6,6 +6,35 @@ from .fields import ( ) class Tag(object): def __init__(self, tag_id): self.tag_id = tag_id def __eq__(self, other): return isinstance(other, Tag) and self.tag_id == other.tag_id class TagField(models.SmallIntegerField): def from_db_value(self, value, expression, connection, context): if value is None: return value return Tag(int(value)) def to_python(self, value): if isinstance(value, Tag): return value if value is None: return value return Tag(int(value)) def get_prep_value(self, value): return value.tag_id def get_db_prep_value(self, value, connection, prepared=False): return self.get_prep_value(value) class PostgreSQLModel(models.Model): class Meta: abstract = True Loading Loading @@ -38,6 +67,7 @@ class OtherTypesArrayModel(PostgreSQLModel): ips = ArrayField(models.GenericIPAddressField()) uuids = ArrayField(models.UUIDField()) decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2)) tags = ArrayField(TagField(), blank=True, null=True) class HStoreModel(PostgreSQLModel): Loading tests/postgres_tests/test_array.py +21 −1 Original line number Diff line number Diff line Loading @@ -15,7 +15,7 @@ from . import PostgreSQLTestCase from .models import ( ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, PostgreSQLModel, PostgreSQLModel, Tag, ) try: Loading Loading @@ -92,12 +92,24 @@ class TestSaveLoad(PostgreSQLTestCase): ips=['192.168.0.1', '::1'], uuids=[uuid.uuid4()], decimals=[decimal.Decimal(1.25), 1.75], tags=[Tag(1), Tag(2), Tag(3)], ) instance.save() loaded = OtherTypesArrayModel.objects.get() self.assertEqual(instance.ips, loaded.ips) self.assertEqual(instance.uuids, loaded.uuids) self.assertEqual(instance.decimals, loaded.decimals) self.assertEqual(instance.tags, loaded.tags) def test_null_from_db_value_handling(self): instance = OtherTypesArrayModel.objects.create( ips=['192.168.0.1', '::1'], uuids=[uuid.uuid4()], decimals=[decimal.Decimal(1.25), 1.75], tags=None, ) instance.refresh_from_db() self.assertIsNone(instance.tags) def test_model_set_on_base_field(self): instance = IntegerArrayModel() Loading Loading @@ -306,11 +318,13 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase): self.ips = ['192.168.0.1', '::1'] self.uuids = [uuid.uuid4()] self.decimals = [decimal.Decimal(1.25), 1.75] self.tags = [Tag(1), Tag(2), Tag(3)] self.objs = [ OtherTypesArrayModel.objects.create( ips=self.ips, uuids=self.uuids, decimals=self.decimals, tags=self.tags, ) ] Loading @@ -332,6 +346,12 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase): self.objs ) def test_exact_tags(self): self.assertSequenceEqual( OtherTypesArrayModel.objects.filter(tags=self.tags), self.objs ) @isolate_apps('postgres_tests') class TestChecks(PostgreSQLTestCase): Loading Loading
django/contrib/postgres/fields/array.py +12 −0 Original line number Diff line number Diff line Loading @@ -28,6 +28,10 @@ class ArrayField(Field): if self.size: self.default_validators = self.default_validators[:] self.default_validators.append(ArrayMaxLengthValidator(self.size)) # For performance, only add a from_db_value() method if the base field # implements it. if hasattr(self.base_field, 'from_db_value'): self.from_db_value = self._from_db_value super(ArrayField, self).__init__(**kwargs) @property Loading Loading @@ -100,6 +104,14 @@ class ArrayField(Field): value = [self.base_field.to_python(val) for val in vals] return value def _from_db_value(self, value, expression, connection, context): if value is None: return value return [ self.base_field.from_db_value(item, expression, connection, context) for item in value ] def value_to_string(self, obj): values = [] vals = self.value_from_object(obj) Loading
tests/postgres_tests/migrations/0002_create_test_models.py +2 −0 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ from __future__ import unicode_literals from django.db import migrations, models from ..fields import * # NOQA from ..models import TagField class Migration(migrations.Migration): Loading Loading @@ -55,6 +56,7 @@ class Migration(migrations.Migration): ('ips', ArrayField(models.GenericIPAddressField(), size=None)), ('uuids', ArrayField(models.UUIDField(), size=None)), ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), ('tags', ArrayField(TagField(), blank=True, null=True, size=None)), ], options={ 'required_db_vendor': 'postgresql', Loading
tests/postgres_tests/models.py +30 −0 Original line number Diff line number Diff line Loading @@ -6,6 +6,35 @@ from .fields import ( ) class Tag(object): def __init__(self, tag_id): self.tag_id = tag_id def __eq__(self, other): return isinstance(other, Tag) and self.tag_id == other.tag_id class TagField(models.SmallIntegerField): def from_db_value(self, value, expression, connection, context): if value is None: return value return Tag(int(value)) def to_python(self, value): if isinstance(value, Tag): return value if value is None: return value return Tag(int(value)) def get_prep_value(self, value): return value.tag_id def get_db_prep_value(self, value, connection, prepared=False): return self.get_prep_value(value) class PostgreSQLModel(models.Model): class Meta: abstract = True Loading Loading @@ -38,6 +67,7 @@ class OtherTypesArrayModel(PostgreSQLModel): ips = ArrayField(models.GenericIPAddressField()) uuids = ArrayField(models.UUIDField()) decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2)) tags = ArrayField(TagField(), blank=True, null=True) class HStoreModel(PostgreSQLModel): Loading
tests/postgres_tests/test_array.py +21 −1 Original line number Diff line number Diff line Loading @@ -15,7 +15,7 @@ from . import PostgreSQLTestCase from .models import ( ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, PostgreSQLModel, PostgreSQLModel, Tag, ) try: Loading Loading @@ -92,12 +92,24 @@ class TestSaveLoad(PostgreSQLTestCase): ips=['192.168.0.1', '::1'], uuids=[uuid.uuid4()], decimals=[decimal.Decimal(1.25), 1.75], tags=[Tag(1), Tag(2), Tag(3)], ) instance.save() loaded = OtherTypesArrayModel.objects.get() self.assertEqual(instance.ips, loaded.ips) self.assertEqual(instance.uuids, loaded.uuids) self.assertEqual(instance.decimals, loaded.decimals) self.assertEqual(instance.tags, loaded.tags) def test_null_from_db_value_handling(self): instance = OtherTypesArrayModel.objects.create( ips=['192.168.0.1', '::1'], uuids=[uuid.uuid4()], decimals=[decimal.Decimal(1.25), 1.75], tags=None, ) instance.refresh_from_db() self.assertIsNone(instance.tags) def test_model_set_on_base_field(self): instance = IntegerArrayModel() Loading Loading @@ -306,11 +318,13 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase): self.ips = ['192.168.0.1', '::1'] self.uuids = [uuid.uuid4()] self.decimals = [decimal.Decimal(1.25), 1.75] self.tags = [Tag(1), Tag(2), Tag(3)] self.objs = [ OtherTypesArrayModel.objects.create( ips=self.ips, uuids=self.uuids, decimals=self.decimals, tags=self.tags, ) ] Loading @@ -332,6 +346,12 @@ class TestOtherTypesExactQuerying(PostgreSQLTestCase): self.objs ) def test_exact_tags(self): self.assertSequenceEqual( OtherTypesArrayModel.objects.filter(tags=self.tags), self.objs ) @isolate_apps('postgres_tests') class TestChecks(PostgreSQLTestCase): Loading