Commit 18983f0e authored by Russell Keith-Magee's avatar Russell Keith-Magee
Browse files

Fixed #13003 -- Ensured that ._state.db is set correctly for select_related()...

Fixed #13003 -- Ensured that ._state.db is set correctly for select_related() queries. Thanks to Alex Gaynor for the report.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@12701 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent 3508a86d
Loading
Loading
Loading
Loading
+15 −9
Original line number Diff line number Diff line
@@ -267,7 +267,7 @@ class QuerySet(object):
        for row in compiler.results_iter():
            if fill_cache:
                obj, _ = get_cached_row(self.model, row,
                            index_start, max_depth,
                            index_start, using=self.db, max_depth=max_depth,
                            requested=requested, offset=len(aggregate_select),
                            only_load=only_load)
            else:
@@ -279,6 +279,9 @@ class QuerySet(object):
                    # Omit aggregates in object creation.
                    obj = self.model(*row[index_start:aggregate_start])

                # Store the source database of the object
                obj._state.db = self.db

            for i, k in enumerate(extra_select):
                setattr(obj, k, row[i])

@@ -286,9 +289,6 @@ class QuerySet(object):
            for i, aggregate in enumerate(aggregate_select):
                setattr(obj, aggregate, row[i+aggregate_start])

            # Store the source database of the object
            obj._state.db = self.db

            yield obj

    def aggregate(self, *args, **kwargs):
@@ -1112,7 +1112,7 @@ class EmptyQuerySet(QuerySet):
    value_annotation = False


def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
                   requested=None, offset=0, only_load=None):
    """
    Helper function that recursively returns an object with the specified
@@ -1126,6 +1126,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
     * row - the row of data returned by the database cursor
     * index_start - the index of the row at which data for this
       object is known to start
     * using - the database alias on which the query is being executed.
     * max_depth - the maximum depth to which a select_related()
       relationship should be explored.
     * cur_depth - the current depth in the select_related() tree.
@@ -1170,6 +1171,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
            obj = klass(**dict(zip(init_list, fields)))
        else:
            obj = klass(*fields)

    else:
        # Load all fields on klass
        field_count = len(klass._meta.fields)
@@ -1182,6 +1184,10 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
        else:
            obj = klass(*fields)

    # If an object was retrieved, set the database state.
    if obj:
        obj._state.db = using

    index_end = index_start + field_count + offset
    # Iterate over each related object, populating any
    # select_related() fields
@@ -1193,8 +1199,8 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
        else:
            next = None
        # Recursively retrieve the data for the related object
        cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
                cur_depth+1, next)
        cached_row = get_cached_row(f.rel.to, row, index_end, using,
                max_depth, cur_depth+1, next)
        # If the recursive descent found an object, populate the
        # descriptor caches relevant to the object
        if cached_row:
@@ -1222,8 +1228,8 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
                continue
            next = requested[f.related_query_name()]
            # Recursively retrieve the data for the related object
            cached_row = get_cached_row(model, row, index_end, max_depth,
                cur_depth+1, next)
            cached_row = get_cached_row(model, row, index_end, using,
                max_depth, cur_depth+1, next)
            # If the recursive descent found an object, populate the
            # descriptor caches relevant to the object
            if cached_row:
+14 −0
Original line number Diff line number Diff line
@@ -641,6 +641,20 @@ class QueryTestCase(TestCase):
        val = Book.objects.raw('SELECT id FROM "multiple_database_book"').using('other')
        self.assertEqual(map(lambda o: o.pk, val), [dive.pk])

    def test_select_related(self):
        "Database assignment is retained if an object is retrieved with select_related()"
        # Create a book and author on the other database
        mark = Person.objects.using('other').create(name="Mark Pilgrim")
        dive = Book.objects.using('other').create(title="Dive into Python",
                                                  published=datetime.date(2009, 5, 4),
                                                  editor=mark)

        # Retrieve the Person using select_related()
        book = Book.objects.using('other').select_related('editor').get(title="Dive into Python")

        # The editor instance should have a db state
        self.assertEqual(book.editor._state.db, 'other')

class TestRouter(object):
    # A test router. The behaviour is vaguely master/slave, but the
    # databases aren't assumed to propagate changes.