Commit 58cd220f authored by Russell Keith-Magee's avatar Russell Keith-Magee
Browse files

Fixed #7270 -- Added the ability to follow reverse OneToOneFields in...

Fixed #7270 -- Added the ability to follow reverse OneToOneFields in select_related(). Thanks to George Vilches, Ben Davis, and Alex Gaynor for their work on various stages of this patch.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@12307 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent 8e8d4b58
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -189,7 +189,7 @@ class SingleRelatedObjectDescriptor(object):
    # SingleRelatedObjectDescriptor instance.
    def __init__(self, related):
        self.related = related
        self.cache_name = '_%s_cache' % related.get_accessor_name()
        self.cache_name = related.get_cache_name()

    def __get__(self, instance, instance_type=None):
        if instance is None:
@@ -319,7 +319,7 @@ class ReverseSingleRelatedObjectDescriptor(object):
            # cache. This cache also might not exist if the related object
            # hasn't been accessed yet.
            if related:
                cache_name = '_%s_cache' % self.field.related.get_accessor_name()
                cache_name = self.field.related.get_cache_name()
                try:
                    delattr(related, cache_name)
                except AttributeError:
+73 −1
Original line number Diff line number Diff line
@@ -1116,6 +1116,29 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
    """
    Helper function that recursively returns an object with the specified
    related attributes already populated.

    This method may be called recursively to populate deep select_related()
    clauses.

    Arguments:
     * klass - the class to retrieve (and instantiate)
     * row - the row of data returned by the database cursor
     * index_start - the index of the row at which data for this
       object is known to start
     * max_depth - the maximum depth to which a select_related()
       relationship should be explored.
     * cur_depth - the current depth in the select_related() tree.
       Used in recursive calls to determin if we should dig deeper.
     * requested - A dictionary describing the select_related() tree
       that is to be retrieved. keys are field names; values are
       dictionaries describing the keys on that related object that
       are themselves to be select_related().
     * offset - the number of additional fields that are known to
       exist in `row` for `klass`. This usually means the number of
       annotated results on `klass`.
     * only_load - if the query has had only() or defer() applied,
       this is the list of field names that will be returned. If None,
       the full field list for `klass` can be assumed.
    """
    if max_depth and requested is None and cur_depth > max_depth:
        # We've recursed deeply enough; stop now.
@@ -1127,14 +1150,18 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
        # Handle deferred fields.
        skip = set()
        init_list = []
        pk_val = row[index_start + klass._meta.pk_index()]
        # Build the list of fields that *haven't* been requested
        for field in klass._meta.fields:
            if field.name not in load_fields:
                skip.add(field.name)
            else:
                init_list.append(field.attname)
        # Retrieve all the requested fields
        field_count = len(init_list)
        fields = row[index_start : index_start + field_count]
        # If all the select_related columns are None, then the related
        # object must be non-existent - set the relation to None.
        # Otherwise, construct the related object.
        if fields == (None,) * field_count:
            obj = None
        elif skip:
@@ -1143,14 +1170,20 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
        else:
            obj = klass(*fields)
    else:
        # Load all fields on klass
        field_count = len(klass._meta.fields)
        fields = row[index_start : index_start + field_count]
        # If all the select_related columns are None, then the related
        # object must be non-existent - set the relation to None.
        # Otherwise, construct the related object.
        if fields == (None,) * field_count:
            obj = None
        else:
            obj = klass(*fields)

    index_end = index_start + field_count + offset
    # Iterate over each related object, populating any
    # select_related() fields
    for f in klass._meta.fields:
        if not select_related_descend(f, restricted, requested):
            continue
@@ -1158,12 +1191,51 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
            next = requested[f.name]
        else:
            next = None
        # Recursively retrieve the data for the related object
        cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
                cur_depth+1, next)
        # If the recursive descent found an object, populate the
        # descriptor caches relevant to the object
        if cached_row:
            rel_obj, index_end = cached_row
            if obj is not None:
                # If the base object exists, populate the
                # descriptor cache
                setattr(obj, f.get_cache_name(), rel_obj)
            if f.unique:
                # If the field is unique, populate the
                # reverse descriptor cache on the related object
                setattr(rel_obj, f.related.get_cache_name(), obj)

    # Now do the same, but for reverse related objects.
    # Only handle the restricted case - i.e., don't do a depth
    # descent into reverse relations unless explicitly requested
    if restricted:
        related_fields = [
            (o.field, o.model)
            for o in klass._meta.get_all_related_objects()
            if o.field.unique
        ]
        for f, model in related_fields:
            if not select_related_descend(f, restricted, requested, reverse=True):
                continue
            next = requested[f.related_query_name()]
            # Recursively retrieve the data for the related object
            cached_row = get_cached_row(model, row, index_end, max_depth,
                cur_depth+1, next)
            # If the recursive descent found an object, populate the
            # descriptor caches relevant to the object
            if cached_row:
                rel_obj, index_end = cached_row
                if obj is not None:
                    # If the field is unique, populate the
                    # reverse descriptor cache
                    setattr(obj, f.related.get_cache_name(), rel_obj)
                if rel_obj is not None:
                    # If the related object exists, populate
                    # the descriptor cache.
                    setattr(rel_obj, f.get_cache_name(), obj)

    return obj, index_end

def delete_objects(seen_objs, using):
+14 −4
Original line number Diff line number Diff line
@@ -197,18 +197,28 @@ class DeferredAttribute(object):
        """
        instance.__dict__[self.field_name] = value

def select_related_descend(field, restricted, requested):
def select_related_descend(field, restricted, requested, reverse=False):
    """
    Returns True if this field should be used to descend deeper for
    select_related() purposes. Used by both the query construction code
    (sql.query.fill_related_selections()) and the model instance creation code
    (query.get_cached_row()).

    Arguments:
     * field - the field to be checked
     * restricted - a boolean field, indicating if the field list has been
       manually restricted using a requested clause)
     * requested - The select_related() dictionary.
     * reverse - boolean, True if we are checking a reverse select related
    """
    if not field.rel:
        return False
    if field.rel.parent_link:
    if field.rel.parent_link and not reverse:
        return False
    if restricted:
        if reverse and field.related_query_name() not in requested:
            return False
    if restricted and field.name not in requested:
        if not reverse and field.name not in requested:
            return False
    if not restricted and field.null:
        return False
+3 −0
Original line number Diff line number Diff line
@@ -45,3 +45,6 @@ class RelatedObject(object):
            return self.field.rel.related_name or (self.opts.object_name.lower() + '_set')
        else:
            return self.field.rel.related_name or (self.opts.object_name.lower())

    def get_cache_name(self):
        return "_%s_cache" % self.get_accessor_name()
+67 −1
Original line number Diff line number Diff line
@@ -520,7 +520,7 @@ class SQLCompiler(object):

        # Setup for the case when only particular related fields should be
        # included in the related selection.
        if requested is None and restricted is not False:
        if requested is None:
            if isinstance(self.query.select_related, dict):
                requested = self.query.select_related
                restricted = True
@@ -600,6 +600,72 @@ class SQLCompiler(object):
            self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
                    used, next, restricted, new_nullable, dupe_set, avoid)

        if restricted:
            related_fields = [
                (o.field, o.model)
                for o in opts.get_all_related_objects()
                if o.field.unique
            ]
            for f, model in related_fields:
                if not select_related_descend(f, restricted, requested, reverse=True):
                    continue
                # The "avoid" set is aliases we want to avoid just for this
                # particular branch of the recursion. They aren't permanently
                # forbidden from reuse in the related selection tables (which is
                # what "used" specifies).
                avoid = avoid_set.copy()
                dupe_set = orig_dupe_set.copy()
                table = model._meta.db_table

                int_opts = opts
                alias = root_alias
                alias_chain = []
                chain = opts.get_base_chain(f.rel.to)
                if chain is not None:
                    for int_model in chain:
                        # Proxy model have elements in base chain
                        # with no parents, assign the new options
                        # object and skip to the next base in that
                        # case
                        if not int_opts.parents[int_model]:
                            int_opts = int_model._meta
                            continue
                        lhs_col = int_opts.parents[int_model].column
                        dedupe = lhs_col in opts.duplicate_targets
                        if dedupe:
                            avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col),
                                ())
                            dupe_set.add((opts, lhs_col))
                        int_opts = int_model._meta
                        alias = self.query.join(
                            (alias, int_opts.db_table, lhs_col, int_opts.pk.column),
                            exclusions=used, promote=True, reuse=used
                        )
                        alias_chain.append(alias)
                        for dupe_opts, dupe_col in dupe_set:
                            self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
                    dedupe = f.column in opts.duplicate_targets
                    if dupe_set or dedupe:
                        avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ()))
                        if dedupe:
                            dupe_set.add((opts, f.column))
                alias = self.query.join(
                    (alias, table, f.rel.get_related_field().column, f.column),
                    exclusions=used.union(avoid),
                    promote=True
                )
                used.add(alias)
                columns, aliases = self.get_default_columns(start_alias=alias,
                    opts=model._meta, as_pairs=True)
                self.query.related_select_cols.extend(columns)
                self.query.related_select_fields.extend(model._meta.fields)

                next = requested.get(f.related_query_name(), {})
                new_nullable = f.null or None

                self.fill_related_selections(model._meta, table, cur_depth+1,
                    used, next, restricted, new_nullable)

    def deferred_to_columns(self):
        """
        Converts the self.deferred_loading data structure to mapping of table
Loading