Commit a2dd618e authored by Marc Tamlyn's avatar Marc Tamlyn
Browse files

Fixed #22648 -- Transform.output_type should respect overridden custom_lookup and custom_transform.

Previously, class lookups from the output_type would be used, but any
changes to custom_lookup or custom_transform would be ignored.
parent 11932e97
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -22,18 +22,20 @@ class RegisterLookupMixin(object):
        except AttributeError:
            # This class didn't have any class_lookups
            pass
        if hasattr(self, 'output_type'):
            return self.output_type.get_lookup(lookup_name)
        return None

    def get_lookup(self, lookup_name):
        found = self._get_lookup(lookup_name)
        if found is None and hasattr(self, 'output_type'):
            return self.output_type.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(self, lookup_name):
        found = self._get_lookup(lookup_name)
        if found is None and hasattr(self, 'output_type'):
            return self.output_type.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found
+60 −0
Original line number Diff line number Diff line
@@ -89,6 +89,47 @@ class YearLte(models.lookups.LessThanOrEqual):
YearTransform.register_lookup(YearLte)


class SQLFunc(models.Lookup):
    def __init__(self, name, *args, **kwargs):
        super(SQLFunc, self).__init__(*args, **kwargs)
        self.name = name

    def as_sql(self, qn, connection):
        return '%s()', [self.name]

    @property
    def output_type(self):
        return CustomField()


class SQLFuncFactory(object):

    def __init__(self, name):
        self.name = name

    def __call__(self, *args, **kwargs):
        return SQLFunc(self.name, *args, **kwargs)


class CustomField(models.Field):

    def get_lookup(self, lookup_name):
        if lookup_name.startswith('lookupfunc_'):
            key, name = lookup_name.split('_', 1)
            return SQLFuncFactory(name)
        return super(CustomField, self).get_lookup(lookup_name)

    def get_transform(self, lookup_name):
        if lookup_name.startswith('transformfunc_'):
            key, name = lookup_name.split('_', 1)
            return SQLFuncFactory(name)
        return super(CustomField, self).get_transform(lookup_name)


class CustomModel(models.Model):
    field = CustomField()


# We will register this class temporarily in the test method.


@@ -341,3 +382,22 @@ class LookupTransformCallOrderTests(TestCase):

        finally:
            models.DateField._unregister_lookup(TrackCallsYearTransform)


class CustomisedMethodsTests(TestCase):

    def test_overridden_get_lookup(self):
        q = CustomModel.objects.filter(field__lookupfunc_monkeys=3)
        self.assertIn('monkeys()', str(q.query))

    def test_overridden_get_transform(self):
        q = CustomModel.objects.filter(field__transformfunc_banana=3)
        self.assertIn('banana()', str(q.query))

    def test_overridden_get_lookup_chain(self):
        q = CustomModel.objects.filter(field__transformfunc_banana__lookupfunc_elephants=3)
        self.assertIn('elephants()', str(q.query))

    def test_overridden_get_transform_chain(self):
        q = CustomModel.objects.filter(field__transformfunc_banana__transformfunc_pear=3)
        self.assertIn('pear()', str(q.query))