Commit 04240b23 authored by acrefoot's avatar acrefoot Committed by Tim Graham
Browse files

Refs #19527 -- Allowed QuerySet.bulk_create() to set the primary key of its objects.

PostgreSQL support only.

Thanks Vladislav Manchev and alesasnouski for working on the patch.
parent 60633ef3
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ class BaseDatabaseFeatures(object):

    can_use_chunked_reads = True
    can_return_id_from_insert = False
    can_return_ids_from_bulk_insert = False
    has_bulk_insert = False
    uses_savepoints = False
    can_release_savepoints = False
+1 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ from django.db.utils import InterfaceError
class DatabaseFeatures(BaseDatabaseFeatures):
    allows_group_by_selected_pks = True
    can_return_id_from_insert = True
    can_return_ids_from_bulk_insert = True
    has_real_datatype = True
    has_native_uuid_field = True
    has_native_duration_field = True
+8 −0
Original line number Diff line number Diff line
@@ -59,6 +59,14 @@ class DatabaseOperations(BaseDatabaseOperations):
    def deferrable_sql(self):
        return " DEFERRABLE INITIALLY DEFERRED"

    def fetch_returned_insert_ids(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table that has an auto-incrementing ID, return the
        list of newly created IDs.
        """
        return [item[0] for item in cursor.fetchall()]

    def lookup_cast(self, lookup_type, internal_type=None):
        lookup = '%s'

+33 −16
Original line number Diff line number Diff line
@@ -411,17 +411,21 @@ class QuerySet(object):
        Inserts each of the instances into the database. This does *not* call
        save() on each of the instances, does not send any pre/post save
        signals, and does not set the primary key attribute if it is an
        autoincrement field. Multi-table models are not supported.
        """
        # So this case is fun. When you bulk insert you don't get the primary
        # keys back (if it's an autoincrement), so you can't insert into the
        # child tables which references this. There are two workarounds, 1)
        # this could be implemented if you didn't have an autoincrement pk,
        # and 2) you could do it by doing O(n) normal inserts into the parent
        # tables to get the primary keys back, and then doing a single bulk
        # insert into the childmost table. Some databases might allow doing
        # this by using RETURNING clause for the insert query. We're punting
        # on these for now because they are relatively rare cases.
        autoincrement field (except if features.can_return_ids_from_bulk_insert=True).
        Multi-table models are not supported.
        """
        # When you bulk insert you don't get the primary keys back (if it's an
        # autoincrement, except if can_return_ids_from_bulk_insert=True), so
        # you can't insert into the child tables which references this. There
        # are two workarounds:
        # 1) This could be implemented if you didn't have an autoincrement pk
        # 2) You could do it by doing O(n) normal inserts into the parent
        #    tables to get the primary keys back and then doing a single bulk
        #    insert into the childmost table.
        # We currently set the primary keys on the objects when using
        # PostgreSQL via the RETURNING ID clause. It should be possible for
        # Oracle as well, but the semantics for  extracting the primary keys is
        # trickier so it's not done yet.
        assert batch_size is None or batch_size > 0
        # Check that the parents share the same concrete model with the our
        # model to detect the inheritance pattern ConcreteGrandParent ->
@@ -447,7 +451,11 @@ class QuerySet(object):
                    self._batched_insert(objs_with_pk, fields, batch_size)
                if objs_without_pk:
                    fields = [f for f in fields if not isinstance(f, AutoField)]
                    self._batched_insert(objs_without_pk, fields, batch_size)
                    ids = self._batched_insert(objs_without_pk, fields, batch_size)
                    if connection.features.can_return_ids_from_bulk_insert:
                        assert len(ids) == len(objs_without_pk)
                    for i in range(len(ids)):
                        objs_without_pk[i].pk = ids[i]

        return objs

@@ -1051,10 +1059,19 @@ class QuerySet(object):
            return
        ops = connections[self.db].ops
        batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
        for batch in [objs[i:i + batch_size]
                      for i in range(0, len(objs), batch_size)]:
            self.model._base_manager._insert(batch, fields=fields,
                                             using=self.db)
        inserted_ids = []
        for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
            if connections[self.db].features.can_return_ids_from_bulk_insert:
                inserted_id = self.model._base_manager._insert(
                    item, fields=fields, using=self.db, return_id=True
                )
                if len(objs) > 1:
                    inserted_ids.extend(inserted_id)
                if len(objs) == 1:
                    inserted_ids.append(inserted_id)
            else:
                self.model._base_manager._insert(item, fields=fields, using=self.db)
        return inserted_ids

    def _clone(self, **kwargs):
        query = self.query.clone()
+14 −4
Original line number Diff line number Diff line
@@ -1019,16 +1019,20 @@ class SQLInsertCompiler(SQLCompiler):
        placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

        if self.return_id and self.connection.features.can_return_id_from_insert:
            if self.connection.features.can_return_ids_from_bulk_insert:
                result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
                params = param_rows
            else:
                result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
                params = param_rows[0]
            col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
            result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
            r_fmt, r_params = self.connection.ops.return_insert_id()
            # Skip empty r_fmt to allow subclasses to customize behavior for
            # 3rd party backends. Refs #19096.
            if r_fmt:
                result.append(r_fmt % col)
                params += r_params
            return [(" ".join(result), tuple(params))]
            return [(" ".join(result), tuple(chain.from_iterable(params)))]

        if can_bulk:
            result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
@@ -1040,14 +1044,20 @@ class SQLInsertCompiler(SQLCompiler):
            ]

    def execute_sql(self, return_id=False):
        assert not (return_id and len(self.query.objs) != 1)
        assert not (
            return_id and len(self.query.objs) != 1 and
            not self.connection.features.can_return_ids_from_bulk_insert
        )
        self.return_id = return_id
        with self.connection.cursor() as cursor:
            for sql, params in self.as_sql():
                cursor.execute(sql, params)
            if not (return_id and cursor):
                return
            if self.connection.features.can_return_ids_from_bulk_insert and len(self.query.objs) > 1:
                return self.connection.ops.fetch_returned_insert_ids(cursor)
            if self.connection.features.can_return_id_from_insert:
                assert len(self.query.objs) == 1
                return self.connection.ops.fetch_returned_insert_id(cursor)
            return self.connection.ops.last_insert_id(cursor,
                    self.query.get_meta().db_table, self.query.get_meta().pk.column)
Loading