changeset 41:ce7734a6b6ad

Consistently access database connection through Flask global variables.
author Daniele Nicolodi <daniele@grinta.net>
date Wed, 29 Jun 2011 01:46:06 +0200
parents e640b2302fab
children 5dfa71aadde8
files src/ltpdarepo/__init__.py src/ltpdarepo/query.py src/ltpdarepo/user.py
diffstat 3 files changed, 21 insertions(+), 39 deletions(-) [+]
line wrap: on
line diff
--- a/src/ltpdarepo/__init__.py	Wed Jun 29 00:40:07 2011 +0200
+++ b/src/ltpdarepo/__init__.py	Wed Jun 29 01:46:06 2011 +0200
@@ -13,35 +13,27 @@
 secure(app)
 
 
-def connection():
-    conn = getattr(g, 'db', None)
-    if conn is None:
-        conn = g.db = mysql.connect(host=app.config['HOSTNAME'], db=app.config['DATABASE'],
-                                    user=app.config['USERNAME'], passwd=app.config['PASSWORD'],
-                                    charset='utf8')
-    return conn
-
-
-# open a database connection at each request
 @app.before_request
-def dbconnect():
+def before_request():
     # open database connection
-    conn = connection()
+    g.db = mysql.connect(host=app.config['HOSTNAME'], db=app.config['DATABASE'],
+                         user=app.config['USERNAME'], passwd=app.config['PASSWORD'],
+                         charset='utf8')
 
     # get version information from package
     g.version = get_distribution('ltpdarepo').version
 
     # validate schema revision
-    curs = conn.cursor()
-    curs.execute("""SELECT value+0 FROM options WHERE name='version'""")
+    curs = g.db.cursor()
+    curs.execute("SELECT value+0 FROM options WHERE name='version'")
     g.schema = curs.fetchone()[0]
     #if g.schema != SCHEMA and '/static/' not in request.url:
     #    return render_template('error.html', error=u'500: Needs upgrade'), 500
 
 
-# close it at the end of the request
 @app.after_request
-def dbclose(response):
+def after_request(response):
+    # close database connection
     g.db.close()
     return response
 
@@ -130,4 +122,3 @@
 def main():
     app.config.from_pyfile('config.py')
     app.run(debug=True)
-
--- a/src/ltpdarepo/query.py	Wed Jun 29 00:40:07 2011 +0200
+++ b/src/ltpdarepo/query.py	Wed Jun 29 01:46:06 2011 +0200
@@ -2,7 +2,6 @@
 
 from MySQLdb.cursors import DictCursor
 
-from ltpdarepo import connection
 from ltpdarepo.form import Form
 from wtforms.fields import TextField, HiddenField
 from wtforms import validators
@@ -14,8 +13,7 @@
     name = TextField("Name", validators=[validators.Required()])
 
     def validate_name(form, field):
-        conn = connection()
-        curs = conn.cursor()
+        curs = g.db.cursor()
         query = Query.load(database=form.database.data, name=field.data)
         if query is not None:
             raise ValidationError(u"Query with this name already exists.")
@@ -27,15 +25,14 @@
         self.__dict__ = self
 
     def create(self):
-        conn = connection()
-        curs = con.cursor()
+        curs = g.db.cursor()
         curs.execute("""INSERT INTO queries (name, db, querystring)
                         VALUES (%(name)s, %(db)s, %(querystring)s)""", dict(self))
+        g.db.commit()
 
     @staticmethod
     def load(database, name):
-        conn = connection()
-        curs = conn.cursor(DictCursor)
+        curs = g.db.cursor(DictCursor)
         curs.execute("""SELECT querystring
                         FROM queries WHERE db=%s and name=%s""", (database, name))
         query = curs.fetchone()
--- a/src/ltpdarepo/user.py	Wed Jun 29 00:40:07 2011 +0200
+++ b/src/ltpdarepo/user.py	Wed Jun 29 01:46:06 2011 +0200
@@ -5,7 +5,6 @@
 
 from MySQLdb.cursors import DictCursor
 
-from ltpdarepo import connection
 from ltpdarepo.form import Form
 
 
@@ -63,8 +62,7 @@
 
     @staticmethod
     def load(username):
-        conn = connection()
-        curs = conn.cursor(DictCursor)
+        curs = g.db.cursor(DictCursor)
         curs.execute("""SELECT username,
                                given_name AS name,
                                family_name AS surname,
@@ -83,8 +81,7 @@
         if not self.password:
             self.password = _generate_password()
 
-        conn = connection()
-        curs = conn.cursor()
+        curs = g.db.cursor()
 
         for host in ('localhost', '%'):
             curs.execute("""CREATE USER %s@%s IDENTIFIED BY %s""",
@@ -96,11 +93,10 @@
                      (self.username, self.name, self.surname,
                       self.email, self.telephone, self.institution, self.admin))
 
-        conn.commit()
+        g.db.commit()
 
     def delete(self):
-        conn = connection()
-        curs = conn.cursor()
+        curs = g.db.cursor()
 
         curs.execute("""DELETE FROM users WHERE username=%s""", self.username)
         curs.execute("""SELECT Host FROM mysql.user WHERE User=%s""", self.username)
@@ -108,11 +104,10 @@
         for host in hosts:
             curs.execute("""DROP USER %s@%s""", (self.username, host))
 
-        conn.commit()
+        g.db.commit()
 
     def save(self):
-        conn = connection()
-        curs = conn.cursor()
+        curs = g.db.cursor()
 
         curs.execute("""UPDATE users SET given_name=%s, family_name=%s, email=%s,
                                          institution=%s, telephone=%s, is_admin=%s
@@ -120,7 +115,7 @@
                      (self.name, self.surname, self.email,
                       self.telephone, self.institution, self.admin, self.username))
 
-        conn.commit()
+        g.db.commit()
 
     def passwd(self, password=None):
         if password is not None:
@@ -128,8 +123,7 @@
         if not self.password:
             self.password = _generate_password()
 
-        conn = connection()
-        curs = conn.cursor()
+        curs = g.db.cursor()
 
         curs.execute("""SELECT Host FROM mysql.user WHERE User=%s""", self.username)
         hosts = [row[0] for row in curs.fetchall()]
@@ -137,4 +131,4 @@
             curs.execute("""SET PASSWORD FOR %s@%s = PASSWORD(%s)""",
                          (self.username, host, self.password))
 
-        conn.commit()
+        g.db.commit()