Commit c3dc8379 authored by Ian Kelly's avatar Ian Kelly
Browse files

Fixed #10473: Added Oracle support for "RETURNING" ids from insert statements.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10044 bcc190cf-cafb-0310-a4f2-bffc1f526a37
parent 6d17020c
Loading
Loading
Loading
Loading
+12 −4
Original line number Diff line number Diff line
@@ -162,6 +162,14 @@ class BaseDatabaseOperations(object):
        """
        return None

    def fetch_returned_insert_id(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table that has an auto-incrementing ID, returns the
        newly created ID.
        """
        return cursor.fetchone()[0]

    def field_cast_sql(self, db_type):
        """
        Given a column type (e.g. 'BLOB', 'VARCHAR'), returns the SQL necessary
@@ -249,10 +257,10 @@ class BaseDatabaseOperations(object):

    def return_insert_id(self):
        """
        For backends that support returning the last insert ID as part of an
        insert query, this method returns the SQL to append to the INSERT
        query. The returned fragment should contain a format string to hold
        hold the appropriate column.
        For backends that support returning the last insert ID as part
        of an insert query, this method returns the SQL and params to
        append to the INSERT query. The returned fragment should
        contain a format string to hold the appropriate column.
        """
        pass

+26 −3
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
    uses_custom_query_class = True
    interprets_empty_strings_as_nulls = True
    uses_savepoints = True
    can_return_id_from_insert = True


class DatabaseOperations(BaseDatabaseOperations):
@@ -97,6 +98,9 @@ WHEN (new.%(col_name)s IS NULL)
    def drop_sequence_sql(self, table):
        return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table))

    def fetch_returned_insert_id(self, cursor):
        return long(cursor._insert_id_var.getvalue())

    def field_cast_sql(self, db_type):
        if db_type and db_type.endswith('LOB'):
            return "DBMS_LOB.SUBSTR(%s)"
@@ -152,6 +156,9 @@ WHEN (new.%(col_name)s IS NULL)
        connection.cursor()
        return connection.ops.regex_lookup(lookup_type)

    def return_insert_id(self):
        return "RETURNING %s INTO %%s", (InsertIdVar(),)

    def savepoint_create_sql(self, sid):
        return "SAVEPOINT " + self.quote_name(sid)

@@ -332,8 +339,11 @@ class OracleParam(object):
    parameter when executing the query.
    """

    def __init__(self, param, charset, strings_only=False):
        self.smart_str = smart_str(param, charset, strings_only)
    def __init__(self, param, cursor, strings_only=False):
        if hasattr(param, 'bind_parameter'):
            self.smart_str = param.bind_parameter(cursor)
        else:
            self.smart_str = smart_str(param, cursor.charset, strings_only)
        if hasattr(param, 'input_size'):
            # If parameter has `input_size` attribute, use that.
            self.input_size = param.input_size
@@ -344,6 +354,19 @@ class OracleParam(object):
            self.input_size = None


class InsertIdVar(object):
    """
    A late-binding cursor variable that can be passed to Cursor.execute
    as a parameter, in order to receive the id of the row created by an
    insert statement.
    """

    def bind_parameter(self, cursor):
        param = cursor.var(Database.NUMBER)
        cursor._insert_id_var = param
        return param


class FormatStylePlaceholderCursor(object):
    """
    Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
@@ -363,7 +386,7 @@ class FormatStylePlaceholderCursor(object):
        self.cursor.arraysize = 100

    def _format_params(self, params):
        return tuple([OracleParam(p, self.charset, True) for p in params])
        return tuple([OracleParam(p, self, True) for p in params])

    def _guess_input_sizes(self, params_list):
        sizes = [None] * len(params_list[0])
+1 −1
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ class DatabaseOperations(PostgresqlDatabaseOperations):
        return cursor.query

    def return_insert_id(self):
        return "RETURNING %s"
        return "RETURNING %s", ()

class DatabaseWrapper(BaseDatabaseWrapper):
    operators = {
+6 −3
Original line number Diff line number Diff line
@@ -306,17 +306,20 @@ class InsertQuery(Query):
        result = ['INSERT INTO %s' % qn(opts.db_table)]
        result.append('(%s)' % ', '.join([qn(c) for c in self.columns]))
        result.append('VALUES (%s)' % ', '.join(self.values))
        params = self.params
        if self.connection.features.can_return_id_from_insert:
            col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
            result.append(self.connection.ops.return_insert_id() % col)
        return ' '.join(result), self.params
            r_fmt, r_params = self.connection.ops.return_insert_id()
            result.append(r_fmt % col)
            params = params + r_params
        return ' '.join(result), params

    def execute_sql(self, return_id=False):
        cursor = super(InsertQuery, self).execute_sql(None)
        if not (return_id and cursor):
            return
        if self.connection.features.can_return_id_from_insert:
            return cursor.fetchone()[0]
            return self.connection.ops.fetch_returned_insert_id(cursor)
        return self.connection.ops.last_insert_id(cursor,
                self.model._meta.db_table, self.model._meta.pk.column)