view src/ltpdarepo/__init__.py @ 249:863e3e81498c

Merge with stable
author Daniele Nicolodi <daniele@grinta.net>
date Tue, 27 Dec 2011 19:00:04 +0100
parents e6ed4e03074f
children 5e11cd66721d
line wrap: on
line source

# Copyright 2011 Daniele Nicolodi <nicolodi@science.unitn.it>
#
# This software may be used and distributed according to the terms of
# the GNU Affero General Public License version 3 or any later version.


from datetime import datetime
from pkg_resources import get_distribution
from urlparse import urlparse, urljoin

from flask import Flask, g, request, session, render_template, Markup, redirect, flash, url_for, current_app
from werkzeug.exceptions import default_exceptions, InternalServerError, HTTPException

import MySQLdb as mysql
import MySQLdb.converters as converters
import dateutil.tz

from .security import secure, require, authenticate
from .utils import datetimetz
from .views.browse import module as browse
from .views.databases import module as databases
from .views.feed import url_for_atom_feed, module as feed
from .views.profile import module as profile
from .views.queries import module as queries
from .views.users import module as users


# Database schema revision this code works with; before_request refuses to
# serve requests when the 'version' option stored in the database differs.
SCHEMA = 31


# Customize MySQLdb type conversions for DATETIME fields: values read from
# the database become timezone-aware datetime objects in the UTC timezone,
# and timezone-aware datetime objects are correctly converted to UTC before
# being written.
def datetime_or_none(s):
    """Convert a MySQL DATETIME column value into an aware UTC datetimetz.

    Parsing is delegated to MySQLdb's own DateTime_or_None converter; when
    that yields None, None is returned unchanged.  Otherwise its fields are
    copied into a datetimetz carrying the UTC tzinfo.
    """
    parsed = converters.DateTime_or_None(s)
    if parsed is None:
        return None
    utc = dateutil.tz.tzutc()
    return datetimetz(parsed.year, parsed.month, parsed.day,
                      parsed.hour, parsed.minute, parsed.second,
                      parsed.microsecond, tzinfo=utc)

def datetime_to_literal(value, c):
    """Render a datetime as a MySQL literal, normalizing aware values to UTC.

    Naive datetimes are passed to MySQLdb's DateTime2literal untouched;
    timezone-aware ones are first shifted to UTC.
    """
    normalized = (value if value.tzinfo is None
                  else value.astimezone(dateutil.tz.tzutc()))
    return converters.DateTime2literal(normalized, c)

# private copy of MySQLdb's conversion table (the module-level default is
# left untouched) with the DATETIME decoder and datetime encoders replaced
conversions = converters.conversions.copy()
conversions.update({
    mysql.constants.FIELD_TYPE.DATETIME: datetime_or_none,
    datetime: datetime_to_literal,
    datetimetz: datetime_to_literal,
})


def before_request():
    """Per-request setup hook.

    Always stores the installed package version in ``g.version``.  For any
    endpoint other than ``static`` it also opens a MySQL connection in
    ``g.db`` (using the module's DATETIME ``conversions``) and stores the
    schema revision read from the ``options`` table in ``g.schema``.

    Raises:
        InternalServerError: if the database schema revision is missing or
            differs from SCHEMA.
    """
    # get version information from package
    g.version = get_distribution('ltpdarepo').version

    # optimization: do not open db connection for static resources
    if request.endpoint == 'static':
        return

    # open database connection
    config = current_app.config
    g.db = mysql.connect(host=config['HOSTNAME'], db=config['DATABASE'],
                         user=config['USERNAME'], passwd=config['PASSWORD'],
                         charset='utf8', conv=conversions)

    # validate schema revision
    curs = g.db.cursor()
    try:
        curs.execute("SELECT CAST(value AS UNSIGNED) FROM options WHERE name='version'")
        row = curs.fetchone()
    finally:
        curs.close()
    # a database without the 'version' option has no usable schema: report
    # it through the same error instead of crashing on None[0]
    g.schema = row[0] if row is not None else None
    if g.schema != SCHEMA:
        raise InternalServerError(
            '<p>The database needs to be upgraded.</p><p>'
            'Current database schema version: %s. '
            'Required version: %s.</p>' % (g.schema, SCHEMA))


def teardown_request(exception):
    """Close the request's database connection, if one was opened.

    ``exception`` is accepted to satisfy the teardown hook signature and is
    not used.
    """
    connection = getattr(g, 'db', None)
    if connection is None:
        # e.g. static resources, for which no connection is opened
        return
    connection.close()


def error_handler(error):
    """Render ``error.html`` for *error* and return it with its status code.

    HTTP exceptions are rendered as they are.  Any other exception is
    replaced by an InternalServerError whose description is extended with
    the current traceback, so unexpected failures are reported in the page.

    Returns:
        A ``(body, status_code)`` tuple.
    """
    if not isinstance(error, HTTPException):
        # nicely report tracebacks
        import traceback
        from xml.sax.saxutils import escape
        error = InternalServerError()
        # Error descriptions carry HTML markup, so error.html presumably
        # renders them unescaped: escape the traceback text, otherwise
        # frames like "in <module>" or object reprs are parsed as tags and
        # exception messages could inject markup into the page.
        error.description += '<pre>' + escape(traceback.format_exc()) + '</pre>'
    return render_template('error.html', error=error), error.code


def breadcrumbs():
    """Build the breadcrumb navigation for the current request path.

    Every non-empty path component except the last one (and except the
    first, which gets no link of its own) becomes a link to the
    cumulative path up to that component, preceded by a "home" link to the
    index page.  Links are joined by a right-pointing double angle quote.

    Returns:
        A Markup string with the links, or ``''`` when there is nothing
        beyond the home link to show.
    """
    url = []
    parts = []
    for item in request.path.split('/')[1:-1]:
        url.append(item)
        if item:
            parts.append((item, urljoin(url_for('index'), '/'.join(url))))
    # Path components come from the (URL-decoded) request path, i.e. from
    # the client: Markup's % operator escapes the interpolated values, so
    # they cannot inject HTML even though the result bypasses autoescaping.
    link = Markup(u'<a href="%s">%s</a>')
    out = [link % (url_for('index'), u'home')]
    for name, href in parts[1:]:
        out.append(link % (href, name))
    if len(out) > 1:
        return Markup(u' &#x00BB; ').join(out)
    return ''


def url_for_other_page(page):
    """Return the URL of the current view with the ``p`` (page) argument
    set to *page*, keeping every other view and query argument."""
    params = request.view_args.copy()
    params.update(request.args)
    params['p'] = page
    return url_for(request.endpoint, **params)


def url_for_other_order(field):
    """Return the URL of the current view sorted by *field*.

    Ordering lives in the query string: ``o`` is the sort field (default
    ``'id'``) and ``r`` is the reverse flag (default 0).  Requesting the
    field the page is already sorted by flips the direction; choosing a
    different field keeps the current direction.
    """
    current = request.args.get('o', 'id')
    # 'r' comes straight from the query string: with type=int a malformed
    # value falls back to the default instead of raising ValueError
    reverse = request.args.get('r', 0, type=int)
    if current == field:
        reverse = int(not reverse)
    args = request.view_args.copy()
    args.update(request.args)
    args.update(o=field, r=reverse)
    return url_for(request.endpoint, **args)


def url_for_other_size(size):
    """Return the URL of the current view with the ``n`` (page size)
    argument set to *size*, keeping every other view and query argument."""
    params = request.view_args.copy()
    params.update(request.args)
    params['n'] = size
    return url_for(request.endpoint, **params)


def is_safe_url(target, host_url=None):
    """Tell whether redirecting to *target* keeps the client on this host.

    Args:
        target: redirect target, absolute or relative to *host_url*.
        host_url: base URL of this site; defaults to ``request.host_url``.

    Returns:
        True only when *target* resolves to an http(s) URL whose network
        location equals that of *host_url*.
    """
    if host_url is None:
        host_url = request.host_url
    # Browsers treat '\' like '/' and drop ASCII control characters (tab,
    # newline) and surrounding whitespace, so targets such as '/\evil.com'
    # or '/\t/evil.com' would be followed to another host even though
    # urlparse() sees a same-host path.  Reject them outright.
    if ('\\' in target or target != target.strip()
            or any(ord(char) < 32 for char in target)):
        return False
    ref = urlparse(host_url)
    test = urlparse(urljoin(host_url, target))
    return test.scheme in ('http', 'https') and test.netloc == ref.netloc


class Application(Flask):
    """Flask application serving the ltpdarepo web interface.

    Configuration is layered: the packaged ``config.py`` is loaded first,
    then the optional file named by *conf*, then the keyword arguments,
    each overriding keys set by the previous ones.
    """
    def __init__(self, conf=None, **kwargs):
        super(Application, self).__init__(__name__)
        # security setup from .security; runs before the configuration is
        # loaded -- NOTE(review): confirm secure() reads no config values
        secure(self)

        # configuration
        # 'config.py' is resolved relative to the application root path
        self.config.from_pyfile('config.py')
        if conf is not None:
            self.config.from_pyfile(conf)
        self.config.update(kwargs)

        @self.route('/')
        @require('user')
        def index():
            """List the databases the current user may read.

            Access is guarded by require('user').  A database is listed when
            the user's ``mysql.db`` row grants SELECT on it and it also
            appears in ``available_dbs``.
            """
            curs = g.db.cursor()
            # the username is passed as a query parameter, so MySQLdb quotes it
            curs.execute("""SELECT DISTINCT Db FROM mysql.db, available_dbs
                            WHERE Select_priv='Y' AND User=%s AND Db=db_name
                            ORDER BY Db""", session['username'])
            dbs = [row[0] for row in curs.fetchall()]
            return render_template('index.html', databases=dbs)

        @self.route('/login', methods=['GET', 'POST'])
        def login():
            """Show the login form (GET) or check submitted credentials (POST).

            On success the username is stored in the session and the client
            is redirected to the ``next`` query argument when is_safe_url()
            accepts it, otherwise to the index page.  On failure an error is
            flashed and the form is rendered again.
            """
            if request.method == 'POST':
                if authenticate(request.form['username'], request.form['password']):
                    session['username'] = request.form['username']
                    target = request.args.get('next')
                    # only follow 'next' when it stays on this site
                    # (prevents open redirects)
                    if not target or not is_safe_url(target):
                        target = url_for('index')
                    return redirect(target)
                flash('Login failed.', category='error')

            return render_template('login.html')

        @self.route('/logout')
        def logout():
            """Forget the logged-in user and redirect to the index page."""
            session.pop('username', None)
            return redirect(url_for('index'))

        # database connection
        # before_request opens g.db, teardown_request closes it
        self.before_request(before_request)
        self.teardown_request(teardown_request)

        # template globals
        self.jinja_env.globals['breadcrumbs'] = breadcrumbs
        self.jinja_env.globals['url_for_other_page'] = url_for_other_page
        self.jinja_env.globals['url_for_other_order'] = url_for_other_order
        self.jinja_env.globals['url_for_other_size'] = url_for_other_size
        self.jinja_env.globals['url_for_atom_feed'] = url_for_atom_feed

        # error handlers
        # route every HTTP error code werkzeug knows to error_handler by
        # writing into Flask's internal handler table -- NOTE(review): newer
        # Flask versions use register_error_handler() instead
        for exc in default_exceptions:
            self.error_handler_spec[None][exc] = error_handler

        # blueprints
        # browse and feed share the '/browse' prefix; NOTE(review): confirm
        # their URL rules do not overlap, as registration order may matter
        self.register_blueprint(browse, url_prefix='/browse')
        self.register_blueprint(feed, url_prefix='/browse')
        self.register_blueprint(profile, url_prefix='/user')
        self.register_blueprint(databases, url_prefix='/manage/databases')
        self.register_blueprint(queries, url_prefix='/manage/queries')
        self.register_blueprint(users, url_prefix='/manage/users')


def main(conf=None):
    """Build the application from the optional config file *conf* and run
    Flask's built-in server."""
    Application(conf).run()