Commit 0b6f05ed authored by Anssi Kääriäinen's avatar Anssi Kääriäinen
Browse files

Fixed #19501 -- added Model.from_db() method

The Model.from_db() is intended to be used in cases where customization
of model loading is needed. Reasons can be performance, or adding custom
behavior to the model (for example "dirty field tracking" to issue
automatic update_fields when saving models).

A big thank you to Tim Graham for the review!
parent c26579ea
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -458,6 +458,16 @@ class Model(six.with_metaclass(ModelBase)):
        super(Model, self).__init__()
        signals.post_init.send(sender=self.__class__, instance=self)

    @classmethod
    def from_db(cls, db, field_names, values):
        if cls._deferred:
            new = cls(**dict(zip(field_names, values)))
        else:
            new = cls(*values)
        new._state.adding = False
        new._state.db = db
        return new

    def __repr__(self):
        try:
            u = six.text_type(self)
+51 −64
Original line number Diff line number Diff line
@@ -241,7 +241,6 @@ class QuerySet(object):
        aggregate_select = list(self.query.aggregate_select)

        only_load = self.query.get_loaded_field_names()
        if not fill_cache:
        fields = self.model._meta.concrete_fields

        load_fields = []
@@ -260,9 +259,6 @@ class QuerySet(object):
                    # Therefore, we need to load all fields from this model
                    load_fields.append(field.name)

        index_start = len(extra_select)
        aggregate_start = index_start + len(load_fields or self.model._meta.concrete_fields)

        skip = None
        if load_fields and not fill_cache:
            # Some fields have been deferred, so we have to initialize
@@ -275,30 +271,25 @@ class QuerySet(object):
                else:
                    init_list.append(field.attname)
            model_cls = deferred_class_factory(self.model, skip)
        else:
            model_cls = self.model
            init_list = [f.attname for f in fields]

        # Cache db and model outside the loop
        db = self.db
        model = self.model
        compiler = self.query.get_compiler(using=db)
        index_start = len(extra_select)
        aggregate_start = index_start + len(init_list)

        if fill_cache:
            klass_info = get_klass_info(model, max_depth=max_depth,
            klass_info = get_klass_info(model_cls, max_depth=max_depth,
                                        requested=requested, only_load=only_load)
        for row in compiler.results_iter():
            if fill_cache:
                obj, _ = get_cached_row(row, index_start, db, klass_info,
                                        offset=len(aggregate_select))
            else:
                # Omit aggregates in object creation.
                row_data = row[index_start:aggregate_start]
                if skip:
                    obj = model_cls(**dict(zip(init_list, row_data)))
                else:
                    obj = model(*row_data)

                # Store the source database of the object
                obj._state.db = db
                # This object came from the database; it's not being added.
                obj._state.adding = False
                obj = model_cls.from_db(db, init_list, row[index_start:aggregate_start])

            if extra_select:
                for i, k in enumerate(extra_select):
@@ -1417,6 +1408,21 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
    return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx


def reorder_for_init(model, field_names, values):
    """
    Reorders given field names and values for those fields
    to be in the same order as model.__init__() expects to find them.
    """
    new_names, new_values = [], []
    for f in model._meta.concrete_fields:
        if f.attname not in field_names:
            continue
        new_names.append(f.attname)
        new_values.append(values[field_names.index(f.attname)])
    assert len(new_names) == len(field_names)
    return new_names, new_values


def get_cached_row(row, index_start, using, klass_info, offset=0,
                   parent_data=()):
    """
@@ -1451,18 +1457,19 @@ def get_cached_row(row, index_start, using, klass_info, offset=0,
         fields[pk_idx] == '')):
        obj = None
    elif field_names:
        fields = list(fields)
        values = list(fields)
        parent_values = []
        parent_field_names = []
        for rel_field, value in parent_data:
            field_names.append(rel_field.attname)
            fields.append(value)
        obj = klass(**dict(zip(field_names, fields)))
            parent_field_names.append(rel_field.attname)
            parent_values.append(value)
        field_names, values = reorder_for_init(
            klass, parent_field_names + field_names,
            parent_values + values)
        obj = klass.from_db(using, field_names, values)
    else:
        obj = klass(*fields)
    # If an object was retrieved, set the database state.
    if obj:
        obj._state.db = using
        obj._state.adding = False

        field_names = [f.attname for f in klass._meta.concrete_fields]
        obj = klass.from_db(using, field_names, fields)
    # Instantiate related fields
    index_end = index_start + field_count + offset
    # Iterate over each related object, populating any
@@ -1534,15 +1541,18 @@ class RawQuerySet(object):
        self.params = params or ()
        self.translations = translations or {}

    def __iter__(self):
        # Mapping of attrnames to row column positions. Used for constructing
        # the model using kwargs, needed when not all model's fields are present
        # in the query.
        model_init_field_names = {}
        # A list of tuples of (column name, column position). Used for
        # annotation fields.
        annotation_fields = []
    def resolve_model_init_order(self):
        """
        Resolve the init field names and value positions
        """
        model_init_names = [f.attname for f in self.model._meta.fields
                            if f.attname in self.columns]
        annotation_fields = [(column, pos) for pos, column in enumerate(self.columns)
                             if column not in self.model_fields]
        model_init_order = [self.columns.index(fname) for fname in model_init_names]
        return model_init_names, model_init_order, annotation_fields

    def __iter__(self):
        # Cache some things for performance reasons outside the loop.
        db = self.db
        compiler = connections[db].ops.compiler('SQLCompiler')(
@@ -1553,18 +1563,12 @@ class RawQuerySet(object):
        query = iter(self.query)

        try:
            # Find out which columns are model's fields, and which ones should be
            # annotated to the model.
            for pos, column in enumerate(self.columns):
                if column in self.model_fields:
                    model_init_field_names[self.model_fields[column].attname] = pos
                else:
                    annotation_fields.append((column, pos))
            model_init_names, model_init_pos, annotation_fields = self.resolve_model_init_order()

            # Find out which model's fields are not present in the query.
            skip = set()
            for field in self.model._meta.fields:
                if field.attname not in model_init_field_names:
                if field.attname not in model_init_names:
                    skip.add(field.attname)
            if skip:
                if self.model._meta.pk.attname in skip:
@@ -1572,34 +1576,17 @@ class RawQuerySet(object):
                model_cls = deferred_class_factory(self.model, skip)
            else:
                model_cls = self.model
                # All model's fields are present in the query. So, it is possible
                # to use *args based model instantiation. For each field of the model,
                # record the query column position matching that field.
                model_init_field_pos = []
                for field in self.model._meta.fields:
                    model_init_field_pos.append(model_init_field_names[field.attname])
            if need_resolv_columns:
                fields = [self.model_fields.get(c, None) for c in self.columns]
            # Begin looping through the query values.
            for values in query:
                if need_resolv_columns:
                    values = compiler.resolve_columns(values, fields)
                # Associate fields to values
                if skip:
                    model_init_kwargs = {}
                    for attname, pos in six.iteritems(model_init_field_names):
                        model_init_kwargs[attname] = values[pos]
                    instance = model_cls(**model_init_kwargs)
                else:
                    model_init_args = [values[pos] for pos in model_init_field_pos]
                    instance = model_cls(*model_init_args)
                model_init_values = [values[pos] for pos in model_init_pos]
                instance = model_cls.from_db(db, model_init_names, model_init_values)
                if annotation_fields:
                    for column, pos in annotation_fields:
                        setattr(instance, column, values[pos])

                instance._state.db = db
                instance._state.adding = False

                yield instance
        finally:
            # Done iterating the Query. If it has its own cursor, close it.
+54 −0
Original line number Diff line number Diff line
@@ -62,6 +62,60 @@ that, you need to :meth:`~Model.save()`.

        book = Book.objects.create_book("Pride and Prejudice")

Customizing model loading
-------------------------

.. classmethod:: Model.from_db(db, field_names, values)

.. versionadded:: 1.8

The ``from_db()`` method can be used to customize model instance creation
when loading from the database.

The ``db`` argument contains the database alias for the database the model
is loaded from, ``field_names`` contains the names of all loaded fields, and
``values`` contains the loaded values for each field in ``field_names``. The
``field_names`` are in the same order as the ``values``, so it is possible to
use ``cls(**(zip(field_names, values)))`` to instantiate the object. If all
of the model's fields are present, then ``values`` are guaranteed to be in
the order ``__init__()`` expects them. That is, the instance can be created
by ``cls(*values)``. It is possible to check if all fields are present by
consulting ``cls._deferred`` - if ``False``, then all fields have been loaded
from the database.

In addition to creating the new model, the ``from_db()`` method must set the
``adding`` and ``db`` flags in the new instance's ``_state`` attribute.

Below is an example showing how torecord the initial values of fields that
are loaded from the database:: 

    @classmethod
    def from_db(cls, db, field_names, values):
        # default implementation of from_db() (could be replaced
        # with super())
        if cls._deferred:
            instance = cls(**zip(field_names, values))
        else:
            instance = cls(*values)
        instance._state.adding = False
        instance._state.db = db
        # customization to store the original field values on the instance
        instance._loaded_values = zip(field_names, values)
        return instance

    def save(self, *args, **kwargs):
        # Check how the current values differ from ._loaded_values. For example,
        # prevent changing the creator_id of the model. (This example doesn't
        # support cases where 'creator_id' is deferred).
        if not self._state.adding and (
                self.creator_id != self._loaded_values['creator_id']):
            raise ValueError("Updating the value of creator isn't allowed")
        super(...).save(*args, **kwargs)

The example above shows a full ``from_db()`` implementation to clarify how that
is done. In this case it would of course be possible to just use ``super()`` call
in the ``from_db()`` method.

.. _validating-objects:

Validating objects
+4 −0
Original line number Diff line number Diff line
@@ -193,6 +193,10 @@ Models
  when these objects are unpickled in a different version than the one in
  which they were pickled.

* Added :meth:`Model.from_db() <django.db.models.Model.from_db()>` which
  Django uses whenever objects are loaded using the ORM. The method allows
  customizing model loading behavior.

Signals
^^^^^^^