Decorator for sqlite3

I have a simple sqlite3 function

from sqlite3 import connect


def foo(name):
    """Create table *name* with a single ``test TEXT PRIMARY KEY`` column in ``data.db``.

    Does nothing if the table already exists. The connection is closed even
    if the statement fails.
    """
    # SQL identifiers cannot be bound as ``?`` parameters, so quote the name
    # by hand: wrap it in double quotes and double any embedded double quote.
    # This also lets names containing spaces or reserved words work.
    quoted = '"' + str(name).replace('"', '""') + '"'
    conn = connect("data.db")
    try:
        curs = conn.cursor()
        curs.execute(f"CREATE TABLE IF NOT EXISTS {quoted}(test TEXT PRIMARY KEY);")
        conn.commit()
    finally:
        # Without this, an exception in execute() would leave the connection open.
        conn.close()

I want a decorator so that I can write:

from sqlite3 import connect

@db_connect
def foo(name):  # Don't know how to pass the args
    # Desired usage only: `curs` is not defined anywhere in this scope.
    # Getting a cursor into the function is exactly what the question asks.
    curs.execute(f"CREATE TABLE IF NOT EXISTS {name}(test TEXT PRIMARY KEY);")

The goal is that I don't have to open a connection, close it, etc. in every function.

What I've tried:

def db_connect(func):
    # Attempted decorator: opens "data.db" around each call, commits and
    # closes afterwards, and returns whatever func returned. The cursor it
    # creates is never handed to func -- that missing link is the question.
    def _db_connect(*args, **kwargs):
        connection = connect("data.db")
        _unused_cursor = connection.cursor()  # created, but not passed on
        outcome = func(*args, **kwargs)
        connection.commit()
        connection.close()
        return outcome

    return _db_connect

But now I am a bit stuck: how do I pass the cursor to the function, and would my decorator work?



Solution 1:[1]

What you actually need is a context manager, not a decorator.

import sqlite3
from contextlib import contextmanager

@contextmanager
def db_ops(db_name):
    """Yield a cursor on the SQLite database *db_name*, then clean up.

    If the ``with`` block finishes normally, the transaction is committed.
    If it raises, the transaction is rolled back and the exception
    propagates. In both cases the connection is closed.

    :param db_name: database path passed straight to ``sqlite3.connect``.
    """
    conn = sqlite3.connect(db_name)
    try:
        # The connection's own context manager commits on success and rolls
        # back on an exception, but it does not close the connection.
        with conn:
            yield conn.cursor()
    finally:
        # An exception in the caller's with-body is re-raised at the yield
        # above; this finally ensures the connection is still closed.
        conn.close()



# Create the table, insert three rows, then read them back; each step
# uses its own short-lived connection managed by db_ops.
with db_ops('db_path') as cursor:
    cursor.execute('create table if not exists temp (id int, name text)')

sample_rows = [(1, 'a'), (2, 'b'), (3, 'c')]
with db_ops('db_path') as cursor:
    cursor.executemany('insert into temp values (?, ?)', sample_rows)

with db_ops('db_path') as cursor:
    fetched = cursor.execute('select * from temp').fetchall()
    print(fetched)

Output

[(1, 'a'), (2, 'b'), (3, 'c')]

As you can see, you don't have to create a connection or commit anymore.

It is worth noting that the connection object supports the context manager protocol by default, meaning you can do this:

conn = sqlite3.connect(...)
# Per the sqlite3 docs: leaving this block commits the transaction if the
# body succeeded, or rolls it back if the body raised.
with conn:
    ...
# Leaving the with-block does NOT close the connection; call conn.close().

But this only manages the transaction (it commits on success and rolls back if an exception is raised); it does not close the connection, so you still have to call conn.close().

Solution 2:[2]

If you want to use a decorator anyway, just pass the created cursor to the function inside the wrapper:

from sqlite3 import connect

def db_connect(func):
    """Decorate *func* so that it receives a fresh cursor as its first argument.

    Each call opens a connection to ``database.db``. The cursor is passed
    in front of the caller's own arguments. The transaction is committed if
    *func* returns normally and rolled back if it raises; the connection is
    closed either way. *func*'s return value is passed through unchanged.
    """
    from functools import wraps

    @wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def _db_connect(*args, **kwargs):
        conn = connect("database.db")
        try:
            # Connection context manager: commit on success, rollback on error.
            with conn:
                result = func(conn.cursor(), *args, **kwargs)
        finally:
            # Without this, an exception in func would leak the connection.
            conn.close()
        return result

    return _db_connect

@db_connect
def create_table(curs):
    """Create ``testTable`` (id, test_text) unless it already exists."""
    ddl = """
        CREATE TABLE IF NOT EXISTS testTable (
            id INTEGER PRIMARY KEY,
            test_text TEXT NOT NULL
            );"""
    curs.execute(ddl)
    return "table created"

@db_connect
def insert_item(curs, item):
    """Store *item* in ``testTable.test_text`` and return a status message."""
    params = dict(item=item)
    curs.execute("INSERT INTO testTable(test_text) VALUES (:item)", params)
    return "{} inserted".format(item)

@db_connect
def select_all(curs):
    """Return every row of ``testTable``, as produced by ``fetchall()``."""
    return curs.execute("SELECT * from testTable").fetchall()

# Demo: create the table, insert two items, then dump all rows.
print(create_table())
for label in ("item1", "item2"):
    print(insert_item(label))
print(select_all())

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2 Przemek