Commit 7171bf75 authored by Josh Smeaton's avatar Josh Smeaton
Browse files

Refs #14030 -- Added repr methods to all expressions

parent f218a2ff
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
@@ -94,6 +94,13 @@ class Count(Aggregate):
        super(Count, self).__init__(
            expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)

    def __repr__(self):
        return "{}({}, distinct={})".format(
            self.__class__.__name__,
            self.arg_joiner.join(str(arg) for arg in self.source_expressions),
            'False' if self.extra['distinct'] == '' else 'True',
        )

    def convert_value(self, value, connection, context):
        if value is None:
            return 0
@@ -117,6 +124,13 @@ class StdDev(Aggregate):
        self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
        super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)

    def __repr__(self):
        return "{}({}, sample={})".format(
            self.__class__.__name__,
            self.arg_joiner.join(str(arg) for arg in self.source_expressions),
            'False' if self.function == 'STDDEV_POP' else 'True',
        )

    def convert_value(self, value, connection, context):
        if value is None:
            return value
@@ -135,6 +149,13 @@ class Variance(Aggregate):
        self.function = 'VAR_SAMP' if sample else 'VAR_POP'
        super(Variance, self).__init__(expression, output_field=FloatField(), **extra)

    def __repr__(self):
        return "{}({}, sample={})".format(
            self.__class__.__name__,
            self.arg_joiner.join(str(arg) for arg in self.source_expressions),
            'False' if self.function == 'VAR_POP' else 'True',
        )

    def convert_value(self, value, connection, context):
        if value is None:
            return value
+45 −5
Original line number Diff line number Diff line
@@ -340,6 +340,12 @@ class Expression(ExpressionNode):
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return "<{}: {}>".format(self.__class__.__name__, self)

    def __str__(self):
        return "{} {} {}".format(self.lhs, self.connector, self.rhs)

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

@@ -408,7 +414,7 @@ class DurationExpression(Expression):
        return expression_wrapper % sql, expression_params


class F(CombinableMixin):
class F(Combinable):
    """
    An object capable of resolving references to existing query objects.
    """
@@ -419,6 +425,9 @@ class F(CombinableMixin):
        """
        self.name = name

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.name)

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

@@ -446,6 +455,13 @@ class Func(ExpressionNode):
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        extra = ', '.join(str(key) + '=' + str(val) for key, val in self.extra.items())
        if extra:
            return "{}({}, {})".format(self.__class__.__name__, args, extra)
        return "{}({})".format(self.__class__.__name__, args)

    def get_source_expressions(self):
        return self.source_expressions

@@ -504,6 +520,9 @@ class Value(ExpressionNode):
        super(Value, self).__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.value)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        val = self.value
@@ -545,6 +564,9 @@ class RawSQL(ExpressionNode):
        self.sql, self.params = sql, params
        super(RawSQL, self).__init__(output_field=output_field)

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)

    def as_sql(self, compiler, connection):
        return '(%s)' % self.sql, self.params

@@ -556,6 +578,9 @@ class Random(ExpressionNode):
    def __init__(self):
        super(Random, self).__init__(output_field=fields.FloatField())

    def __repr__(self):
        return "Random()"

    def as_sql(self, compiler, connection):
        return connection.ops.random_function_sql(), []

@@ -567,6 +592,10 @@ class Col(ExpressionNode):
        super(Col, self).__init__(output_field=source)
        self.alias, self.target = alias, target

    def __repr__(self):
        return "{}({}, {})".format(
            self.__class__.__name__, self.alias, self.target)

    def as_sql(self, compiler, connection):
        qn = compiler.quote_name_unless_alias
        return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
@@ -588,8 +617,10 @@ class Ref(ExpressionNode):
    """
    def __init__(self, refs, source):
        super(Ref, self).__init__()
        self.source = source
        self.refs = refs
        self.refs, self.source = refs, source

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)

    def get_source_expressions(self):
        return [self.source]
@@ -743,6 +774,9 @@ class Date(ExpressionNode):
        self.col = None
        self.lookup_type = lookup_type

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.lookup, self.lookup_type)

    def get_source_expressions(self):
        return [self.col]

@@ -792,6 +826,10 @@ class DateTime(ExpressionNode):
            self.tzname = timezone._get_timezone_name(tzinfo)
        self.tzinfo = tzinfo

    def __repr__(self):
        return "{}({}, {}, {})".format(
            self.__class__.__name__, self.lookup, self.lookup_type, self.tzinfo)

    def get_source_expressions(self):
        return [self.col]

@@ -833,8 +871,6 @@ class DateTime(ExpressionNode):

class OrderBy(BaseExpression):
    template = '%(expression)s %(ordering)s'
    descending_template = 'DESC'
    ascending_template = 'ASC'

    def __init__(self, expression, descending=False):
        self.descending = descending
@@ -842,6 +878,10 @@ class OrderBy(BaseExpression):
            raise ValueError('expression must be an expression type')
        self.expression = expression

    def __repr__(self):
        return "{}({}, descending={})".format(
            self.__class__.__name__, self.expression, self.descending)

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

+45 −1
Original line number Diff line number Diff line
@@ -6,10 +6,17 @@ import uuid

from django.core.exceptions import FieldError
from django.db import connection, transaction, DatabaseError
from django.db.models import F, Value, TimeField, UUIDField
from django.db.models import TimeField, UUIDField
from django.db.models.aggregates import Avg, Count, Max, Min, StdDev, Sum, Variance
from django.db.models.expressions import (
    Case, Col, Date, DateTime, F, Func, OrderBy,
    Random, RawSQL, Ref, Value, When
)
from django.db.models.functions import Coalesce, Concat, Length, Lower, Substr, Upper
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate
from django.utils import six
from django.utils.timezone import utc

from .models import Company, Employee, Number, Experiment, Time, UUID

@@ -812,3 +819,40 @@ class ValueTests(TestCase):
        UUID.objects.create()
        UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
        self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))


class ReprTests(TestCase):

    def test_expressions(self):
        self.assertEqual(
            repr(Case(When(a=1))),
            "<Case: CASE WHEN <Q: (AND: ('a', 1))> THEN Value(None), ELSE Value(None)>"
        )
        self.assertEqual(repr(Col('alias', 'field')), "Col(alias, field)")
        self.assertEqual(repr(Date('published', 'exact')), "Date(published, exact)")
        self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, UTC)")
        self.assertEqual(repr(F('published')), "F(published)")
        self.assertEqual(repr(F('cost') + F('tax')), "<Expression: F(cost) + F(tax)>")
        self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)")
        self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)')
        self.assertEqual(repr(Random()), "Random()")
        self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])")
        self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))")
        self.assertEqual(repr(Value(1)), "Value(1)")

    def test_functions(self):
        self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))")
        self.assertEqual(repr(Concat('a', 'b')), "Concat(ConcatPair(F(a), F(b)))")
        self.assertEqual(repr(Length('a')), "Length(F(a))")
        self.assertEqual(repr(Lower('a')), "Lower(F(a))")
        self.assertEqual(repr(Substr('a', 1, 3)), "Substr(F(a), Value(1), Value(3))")
        self.assertEqual(repr(Upper('a')), "Upper(F(a))")

    def test_aggregates(self):
        self.assertEqual(repr(Avg('a')), "Avg(F(a))")
        self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)")
        self.assertEqual(repr(Max('a')), "Max(F(a))")
        self.assertEqual(repr(Min('a')), "Min(F(a))")
        self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
        self.assertEqual(repr(Sum('a')), "Sum(F(a))")
        self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")