Commit 48ad2886 authored by Marc Tamlyn's avatar Marc Tamlyn
Browse files

Fixed #24001 -- Added range fields for PostgreSQL.

Added support for PostgreSQL range types to contrib.postgres.

- 5 new model fields
- 4 new form fields
- New validators
- Uses psycopg2's range type implementation in python
parent 916e3880
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
from .array import *  # NOQA
from .hstore import *  # NOQA
from .ranges import *  # NOQA
+156 −0
Original line number Diff line number Diff line
import json

from django.contrib.postgres import lookups, forms
from django.db import models
from django.utils import six

from psycopg2.extras import Range, NumericRange, DateRange, DateTimeTZRange


__all__ = [
    'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
    'FloatRangeField', 'DateTimeRangeField', 'DateRangeField',
]


class RangeField(models.Field):
    empty_strings_allowed = False

    def get_prep_value(self, value):
        if value is None:
            return None
        elif isinstance(value, Range):
            return value
        elif isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        return value

    def to_python(self, value):
        if isinstance(value, six.string_types):
            value = self.range_type(**json.loads(value))
        elif isinstance(value, (list, tuple)):
            value = self.range_type(value[0], value[1])
        return value

    def value_to_string(self, obj):
        value = self._get_val_from_obj(obj)
        if value is None:
            return None
        if value.isempty:
            return json.dumps({"empty": True})
        return json.dumps({
            "lower": value.lower,
            "upper": value.upper,
            "bounds": value._bounds,
        })

    def formfield(self, **kwargs):
        kwargs.setdefault('form_class', self.form_field)
        return super(RangeField, self).formfield(**kwargs)


class IntegerRangeField(RangeField):
    base_field = models.IntegerField()
    range_type = NumericRange
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        return 'int4range'


class BigIntegerRangeField(RangeField):
    base_field = models.BigIntegerField()
    range_type = NumericRange
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        return 'int8range'


class FloatRangeField(RangeField):
    base_field = models.FloatField()
    range_type = NumericRange
    form_field = forms.FloatRangeField

    def db_type(self, connection):
        return 'numrange'


class DateTimeRangeField(RangeField):
    base_field = models.DateTimeField()
    range_type = DateTimeTZRange
    form_field = forms.DateTimeRangeField

    def db_type(self, connection):
        return 'tstzrange'


class DateRangeField(RangeField):
    base_field = models.DateField()
    range_type = DateRange
    form_field = forms.DateRangeField

    def db_type(self, connection):
        return 'daterange'


RangeField.register_lookup(lookups.DataContains)
RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap)


@RangeField.register_lookup
class FullyLessThan(lookups.PostgresSimpleLookup):
    lookup_name = 'fully_lt'
    operator = '<<'


@RangeField.register_lookup
class FullGreaterThan(lookups.PostgresSimpleLookup):
    lookup_name = 'fully_gt'
    operator = '>>'


@RangeField.register_lookup
class NotLessThan(lookups.PostgresSimpleLookup):
    lookup_name = 'not_lt'
    operator = '&>'


@RangeField.register_lookup
class NotGreaterThan(lookups.PostgresSimpleLookup):
    lookup_name = 'not_gt'
    operator = '&<'


@RangeField.register_lookup
class AdjacentToLookup(lookups.PostgresSimpleLookup):
    lookup_name = 'adjacent_to'
    operator = '-|-'


@RangeField.register_lookup
class RangeStartsWith(lookups.FunctionTransform):
    lookup_name = 'startswith'
    function = 'lower'

    @property
    def output_field(self):
        return self.lhs.output_field.base_field


@RangeField.register_lookup
class RangeEndsWith(lookups.FunctionTransform):
    lookup_name = 'endswith'
    function = 'upper'

    @property
    def output_field(self):
        return self.lhs.output_field.base_field


@RangeField.register_lookup
class IsEmpty(lookups.FunctionTransform):
    lookup_name = 'isempty'
    function = 'isempty'
    output_field = models.BooleanField()
+1 −0
Original line number Diff line number Diff line
from .array import *  # NOQA
from .hstore import *  # NOQA
from .ranges import *  # NOQA
+69 −0
Original line number Diff line number Diff line
from django.core import exceptions
from django import forms
from django.utils.translation import ugettext_lazy as _

from psycopg2.extras import NumericRange, DateRange, DateTimeTZRange


__all__ = ['IntegerRangeField', 'FloatRangeField', 'DateTimeRangeField', 'DateRangeField']


class BaseRangeField(forms.MultiValueField):
    default_error_messages = {
        'invalid': _('Enter two valid values.'),
        'bound_ordering': _('The start of the range must not exceed the end of the range.'),
    }

    def __init__(self, **kwargs):
        widget = forms.MultiWidget([self.base_field.widget, self.base_field.widget])
        kwargs.setdefault('widget', widget)
        kwargs.setdefault('fields', [self.base_field(required=False), self.base_field(required=False)])
        kwargs.setdefault('required', False)
        kwargs.setdefault('require_all_fields', False)
        super(BaseRangeField, self).__init__(**kwargs)

    def prepare_value(self, value):
        if isinstance(value, self.range_type):
            return [value.lower, value.upper]
        if value is None:
            return [None, None]
        return value

    def compress(self, values):
        if not values:
            return None
        lower, upper = values
        if lower is not None and upper is not None and lower > upper:
            raise exceptions.ValidationError(
                self.error_messages['bound_ordering'],
                code='bound_ordering',
            )
        try:
            range_value = self.range_type(lower, upper)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages['invalid'],
                code='invalid',
            )
        else:
            return range_value


class IntegerRangeField(BaseRangeField):
    base_field = forms.IntegerField
    range_type = NumericRange


class FloatRangeField(BaseRangeField):
    base_field = forms.FloatField
    range_type = NumericRange


class DateTimeRangeField(BaseRangeField):
    base_field = forms.DateTimeField
    range_type = DateTimeTZRange


class DateRangeField(BaseRangeField):
    base_field = forms.DateField
    range_type = DateRange
+14 −1
Original line number Diff line number Diff line
import copy

from django.core.exceptions import ValidationError
from django.core.validators import MaxLengthValidator, MinLengthValidator
from django.core.validators import (
    MaxLengthValidator, MinLengthValidator, MaxValueValidator,
    MinValueValidator,
)
from django.utils.deconstruct import deconstructible
from django.utils.translation import ungettext_lazy, ugettext_lazy as _

@@ -63,3 +66,13 @@ class KeysValidator(object):

    def __ne__(self, other):
        return not (self == other)


class RangeMaxValueValidator(MaxValueValidator):
    compare = lambda self, a, b: a.upper > b
    message = _('Ensure that this range is completely less than or equal to %(limit_value)s.')


class RangeMinValueValidator(MinValueValidator):
    compare = lambda self, a, b: a.lower < b
    message = _('Ensure that this range is completely greater than or equal to %(limit_value)s.')
Loading