Commit 916e3880 authored by Marc Tamlyn's avatar Marc Tamlyn
Browse files

Move % addition to lookups, refactor postgres lookups.

These refactorings making overriding some text based lookup names on
other fields (specifically `contains`) much cleaner. It also removes a
bunch of duplication in the contrib.postgres lookups.
parent 74f02557
Loading
Loading
Loading
Loading
+10 −40
Original line number Diff line number Diff line
import json

from django.contrib.postgres import lookups
from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions
from django.db.models import Field, Lookup, Transform, IntegerField
from django.db.models import Field, Transform, IntegerField
from django.utils import six
from django.utils.translation import string_concat, ugettext_lazy as _

@@ -74,12 +75,6 @@ class ArrayField(Field):
            return [self.base_field.get_prep_value(i) for i in value]
        return value

    def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
        if lookup_type == 'contains':
            return [self.get_prep_value(value)]
        return super(ArrayField, self).get_db_prep_lookup(lookup_type, value,
                connection, prepared=False)

    def deconstruct(self):
        name, path, args, kwargs = super(ArrayField, self).deconstruct()
        if path == 'django.contrib.postgres.fields.array.ArrayField':
@@ -156,46 +151,21 @@ class ArrayField(Field):


@ArrayField.register_lookup
class ArrayContainsLookup(Lookup):
    lookup_name = 'contains'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        type_cast = self.lhs.output_field.db_type(connection)
        return '%s @> %s::%s' % (lhs, rhs, type_cast), params


@ArrayField.register_lookup
class ArrayContainedByLookup(Lookup):
    lookup_name = 'contained_by'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        return '%s <@ %s' % (lhs, rhs), params

class ArrayContains(lookups.DataContains):
    def as_sql(self, qn, connection):
        sql, params = super(ArrayContains, self).as_sql(qn, connection)
        sql += '::%s' % self.lhs.output_field.db_type(connection)
        return sql, params

@ArrayField.register_lookup
class ArrayOverlapLookup(Lookup):
    lookup_name = 'overlap'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        return '%s && %s' % (lhs, rhs), params
ArrayField.register_lookup(lookups.ContainedBy)
ArrayField.register_lookup(lookups.Overlap)


@ArrayField.register_lookup
class ArrayLenTransform(Transform):
    lookup_name = 'len'

    @property
    def output_field(self):
        return IntegerField()
    output_field = IntegerField()

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
+12 −52
Original line number Diff line number Diff line
import json

from django.contrib.postgres import forms
from django.contrib.postgres import forms, lookups
from django.contrib.postgres.fields.array import ArrayField
from django.core import exceptions
from django.db.models import Field, Lookup, Transform, TextField
from django.db.models import Field, Transform, TextField
from django.utils import six
from django.utils.translation import ugettext_lazy as _

@@ -21,12 +21,6 @@ class HStoreField(Field):
    def db_type(self, connection):
        return 'hstore'

    def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
        if lookup_type == 'contains':
            return [self.get_prep_value(value)]
        return super(HStoreField, self).get_db_prep_lookup(lookup_type, value,
                connection, prepared=False)

    def get_transform(self, name):
        transform = super(HStoreField, self).get_transform(name)
        if transform:
@@ -60,48 +54,20 @@ class HStoreField(Field):
        return super(HStoreField, self).formfield(**defaults)


@HStoreField.register_lookup
class HStoreContainsLookup(Lookup):
    lookup_name = 'contains'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        return '%s @> %s' % (lhs, rhs), params


@HStoreField.register_lookup
class HStoreContainedByLookup(Lookup):
    lookup_name = 'contained_by'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        return '%s <@ %s' % (lhs, rhs), params
HStoreField.register_lookup(lookups.DataContains)
HStoreField.register_lookup(lookups.ContainedBy)


@HStoreField.register_lookup
class HasKeyLookup(Lookup):
class HasKeyLookup(lookups.PostgresSimpleLookup):
    lookup_name = 'has_key'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        return '%s ? %s' % (lhs, rhs), params
    operator = '?'


@HStoreField.register_lookup
class HasKeysLookup(Lookup):
class HasKeysLookup(lookups.PostgresSimpleLookup):
    lookup_name = 'has_keys'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        return '%s ?& %s' % (lhs, rhs), params
    operator = '?&'


class KeyTransform(Transform):
@@ -126,20 +92,14 @@ class KeyTransformFactory(object):


@HStoreField.register_lookup
class KeysTransform(Transform):
class KeysTransform(lookups.FunctionTransform):
    lookup_name = 'keys'
    function = 'akeys'
    output_field = ArrayField(TextField())

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        return 'akeys(%s)' % lhs, params


@HStoreField.register_lookup
class ValuesTransform(Transform):
class ValuesTransform(lookups.FunctionTransform):
    lookup_name = 'values'
    function = 'avals'
    output_field = ArrayField(TextField())

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        return 'avals(%s)' % lhs, params
+32 −6
Original line number Diff line number Diff line
from django.db.models import Transform
from django.db.models import Lookup, Transform


class Unaccent(Transform):
class PostgresSimpleLookup(Lookup):
    def as_sql(self, qn, connection):
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        params = lhs_params + rhs_params
        return '%s %s %s' % (lhs, self.operator, rhs), params


class FunctionTransform(Transform):
    def as_sql(self, qn, connection):
        lhs, params = qn.compile(self.lhs)
        return "%s(%s)" % (self.function, lhs), params


class DataContains(PostgresSimpleLookup):
    lookup_name = 'contains'
    operator = '@>'


class ContainedBy(PostgresSimpleLookup):
    lookup_name = 'contained_by'
    operator = '<@'


class Overlap(PostgresSimpleLookup):
    lookup_name = 'overlap'
    operator = '&&'


class Unaccent(FunctionTransform):
    bilateral = True
    lookup_name = 'unaccent'

    def as_postgresql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        return "UNACCENT(%s)" % lhs, params
    function = 'UNACCENT'
+3 −9
Original line number Diff line number Diff line
@@ -746,7 +746,9 @@ class Field(RegisterLookupMixin):
            return QueryWrapper(('(%s)' % sql), params)

        if lookup_type in ('month', 'day', 'week_day', 'hour', 'minute',
                           'second', 'search', 'regex', 'iregex'):
                           'second', 'search', 'regex', 'iregex', 'contains',
                           'icontains', 'iexact', 'startswith', 'endswith',
                           'istartswith', 'iendswith'):
            return [value]
        elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
            return [self.get_db_prep_value(value, connection=connection,
@@ -754,14 +756,6 @@ class Field(RegisterLookupMixin):
        elif lookup_type in ('range', 'in'):
            return [self.get_db_prep_value(v, connection=connection,
                                           prepared=prepared) for v in value]
        elif lookup_type in ('contains', 'icontains'):
            return ["%%%s%%" % connection.ops.prep_for_like_query(value)]
        elif lookup_type == 'iexact':
            return [connection.ops.prep_for_iexact_query(value)]
        elif lookup_type in ('startswith', 'istartswith'):
            return ["%s%%" % connection.ops.prep_for_like_query(value)]
        elif lookup_type in ('endswith', 'iendswith'):
            return ["%%%s" % connection.ops.prep_for_like_query(value)]
        elif lookup_type == 'isnull':
            return []
        elif lookup_type == 'year':
+51 −1
Original line number Diff line number Diff line
@@ -222,6 +222,14 @@ default_lookups['exact'] = Exact

class IExact(BuiltinLookup):
    lookup_name = 'iexact'

    def process_rhs(self, qn, connection):
        rhs, params = super(IExact, self).process_rhs(qn, connection)
        if params:
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs, params


default_lookups['iexact'] = IExact


@@ -317,31 +325,73 @@ class PatternLookup(BuiltinLookup):

class Contains(PatternLookup):
    lookup_name = 'contains'

    def process_rhs(self, qn, connection):
        rhs, params = super(Contains, self).process_rhs(qn, connection)
        if params and not self.bilateral_transforms:
            params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
        return rhs, params


default_lookups['contains'] = Contains


class IContains(PatternLookup):
class IContains(Contains):
    lookup_name = 'icontains'


default_lookups['icontains'] = IContains


class StartsWith(PatternLookup):
    lookup_name = 'startswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(StartsWith, self).process_rhs(qn, connection)
        if params and not self.bilateral_transforms:
            params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
        return rhs, params


default_lookups['startswith'] = StartsWith


class IStartsWith(PatternLookup):
    lookup_name = 'istartswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(IStartsWith, self).process_rhs(qn, connection)
        if params and not self.bilateral_transforms:
            params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
        return rhs, params


default_lookups['istartswith'] = IStartsWith


class EndsWith(PatternLookup):
    lookup_name = 'endswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(EndsWith, self).process_rhs(qn, connection)
        if params and not self.bilateral_transforms:
            params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
        return rhs, params


default_lookups['endswith'] = EndsWith


class IEndsWith(PatternLookup):
    lookup_name = 'iendswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(IEndsWith, self).process_rhs(qn, connection)
        if params and not self.bilateral_transforms:
            params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
        return rhs, params


default_lookups['iendswith'] = IEndsWith