view LTPDAConnectionManager.m @ 2:b71833fb33ef

More details on the utils.mysql package.
author Daniele Nicolodi <daniele@science.unitn.it>
date Sun, 23 May 2010 22:09:23 +0200
parents d5fef23867bb
children c706c10a76bd
line wrap: on
line source

classdef LTPDAConnectionManager < handle
  % LTPDAConnectionManager  Pool of database connections plus a cache of
  % connection credentials.
  %
  % The manager is stored in the root appdata under the key returned by
  % LTPDAConnectionManager.appdataKey(). Constructing a manager while one
  % is stored returns the stored handle. Use LTPDAConnectionManager.reset()
  % to discard it.
  %
  % The dependent properties credentialsExpiry, cachePassword and
  % maxConnectionsNumber are read on every access from the 'cm' field of
  % the 'LTPDApreferences' root appdata entry.

  properties(SetAccess=private)

    connections = {}; % connections opened through connect()
    credentials = {}; % cached credentials objects

  end % private properties

  properties(Dependent=true)

    credentialsExpiry; % seconds
    cachePassword; % 0=no 1=yes 2=ask
    maxConnectionsNumber;

  end % dependent properties

  methods(Static)

    function reset()
      % Discard the manager stored in the root appdata.
      setappdata(0, LTPDAConnectionManager.appdataKey, []);
    end

    function key = appdataKey()
      % Name of the root appdata entry holding the manager.
      % Defined as a static method to be accessible by the reset static method.
      key = 'LTPDAConnectionManager';
    end

  end % static methods

  methods

    function cm = LTPDAConnectionManager(pl)
      % Return the manager stored in the root appdata.
      % When none is stored (or the stored handle has been deleted), create
      % a new one and store it.
      % NOTE(review): the pl argument is currently unused.

      % load state from appdata
      acm = getappdata(0, cm.appdataKey());

      if isempty(acm) || ~isvalid(acm)
        % TODO: take those from user preferences
        cm.credentials{end+1} = credentials('localhost', 'one');
        cm.credentials{end+1} = credentials('localhost', 'two', 'daniele');

        % store state in appdata
        setappdata(0, cm.appdataKey(), cm);

        import utils.const.*
        utils.helper.msg(msg.PROC1, 'new connection manager');
      else
        cm = acm;
      end
    end


    function val = get.credentialsExpiry(cm)
      % obtain from user preferences
      p = getappdata(0, 'LTPDApreferences');
      val = p.cm.credentialsExpiry;
    end


    function val = get.cachePassword(cm)
      % obtain from user preferences
      p = getappdata(0, 'LTPDApreferences');
      val = p.cm.cachePassword;
    end


    function val = get.maxConnectionsNumber(cm)
      % obtain from user preferences
      p = getappdata(0, 'LTPDApreferences');
      val = p.cm.maxConnectionsNumber;
    end


    function n = count(cm)
      % Remove closed connections from the pool and return how many remain.
      import utils.const.*
      % find closed connections in the pool
      mask = false(numel(cm.connections), 1);
      for kk = 1:numel(cm.connections)
        if cm.connections{kk}.isClosed()
          utils.helper.msg(msg.PROC1, 'connection id=%d closed', kk);
          mask(kk) = true;
        end
      end

      % remove them
      cm.connections(mask) = [];

      % count remaining ones
      n = numel(cm.connections);
    end


    function clear(cm)
      % Remove all cached credentials. The connection pool is not touched.
      cm.credentials = {};
    end


    function conn = connect(cm, varargin)
      % Return a database connection.
      %
      %   conn = cm.connect()                                  interactive
      %   conn = cm.connect(hostname, database[, username[, password]])
      %   conn = cm.connect(pl)
      %
      % With a plist holding a 'connection' key, that object is validated
      % (it must be a java.sql.Connection) and returned as is. Otherwise the
      % 'hostname', 'database', 'username' and 'password' keys are used.
      % When hostname or database is not a string, the database is selected
      % interactively.
      %
      % Errors:
      % - When the pool already holds maxConnectionsNumber open
      %   connections, an error is raised.
      % - When connecting fails, the credentials cache is restored to its
      %   state before the call and the error is rethrown.
      import utils.const.*

      % save current credentials cache
      cache = cm.credentials;

      % count open connections in the pool
      nopen = cm.count();

      % check parameters
      if numel(varargin) == 1 && isa(varargin{1}, 'plist')

        % extract parameters from plist
        pl = varargin{1};

        % check if we have a connection parameter
        conn = find(pl, 'connection');
        if ~isempty(conn)
          % check that it implements java.sql.Connection interface
          if ~isa(conn, 'java.sql.Connection')
            error('### connection is not valid database connection');
          end
          % return this connection
          return;
        end

        % otherwise
        hostname = find(pl, 'hostname');
        database = find(pl, 'database');
        username = find(pl, 'username');
        password = find(pl, 'password');

        if ~ischar(hostname) || ~ischar(database)
          % if there is no hostname and database ignore other parameters
          % and select the database interactively
          varargin = {};
        elseif ~ischar(password)
          % password can not be null but can be an empty string
          varargin = {hostname, database, username};
        else
          varargin = {hostname, database, username, password};
        end
      end

      % check number of connections: opening a new one must not exceed
      % the configured maximum
      if nopen >= cm.maxConnectionsNumber
        error('### too many open connections');
      end

      % connect
      try
        conn = cm.getConnection(varargin{:});
      catch ex
        % restore our copy of the credentials cache
        utils.helper.msg(msg.PROC1, 'undo cache changes');
        cm.credentials = cache;

        % hide implementation details
        %ex.throwAsCaller();
        ex.rethrow()
      end
    end


    function close(cm, ids)
      % Close the pooled connections with indices ids (all by default).
      % Closed connections stay in the pool until the next call to count().
      if nargin < 2
        ids = 1:numel(cm.connections);
      end
      cellfun(@close, cm.connections(ids));
    end


    function add(cm, c)
      % Add the credentials object c to the cache.
      if nargin < 2 || ~isa(c, 'credentials')
        error('### invalid call');
      end
      cm.cacheCredentials(c);
    end

  end % methods

  methods(Access=private)

    function conn = getConnection(cm, varargin)
      % Open a connection and add it to the pool.
      %
      %   conn = cm.getConnection()                                select database
      %   conn = cm.getConnection(hostname, database)
      %   conn = cm.getConnection(hostname, database, username)
      %   conn = cm.getConnection(hostname, database, username, password)
      %
      % With three arguments, credentials come from the cache. When they are
      % incomplete or ambiguous, the user is asked. Access denied errors
      % cause the user to be asked again for credentials.

      import utils.const.*

      switch numel(varargin)
        case 0
          [hostname, database, username] = cm.selectDatabase();
          conn = cm.getConnection(hostname, database, username);

        case 2
          conn = cm.getConnection(varargin{1}, varargin{2}, []);

        case 3
          while true
            % find credentials
            cred = cm.findCredentials(varargin{1}, varargin{2}, varargin{3});
            if isempty(cred)
              % no credentials found
              cred = credentials(varargin{1}, varargin{2}, varargin{3});
            else
              utils.helper.msg(msg.PROC1, 'use cached credentials');
            end

            cache = false;
            asked = numel(cred) > 1 || ~cred.complete;
            if asked
              % ask for which username and password to use
              [username, password, cache] = cm.inputCredentials(cred);

              % cache credentials
              cred = credentials(varargin{1}, varargin{2}, username);
              cm.cacheCredentials(cred);

              % add password to credentials
              cred.password = password;
            end

            % try to connect (openConnection adds the connection to the pool)
            try
              conn = cm.openConnection(cred.hostname, cred.database, cred.username, cred.password);
              break;
            catch ex
              if ~strcmp(ex.identifier, 'utils:jmysql:connect:AccessDenied')
                rethrow(ex);
              end
              % ask for new credentials
              utils.helper.msg(msg.PROC1, ex.message);
              if ~asked
                % A rejected cached password would be offered again at
                % every retry: drop it so the user is asked instead.
                % NOTE(review): assumes complete() is false once the
                % password is cleared; confirm in the credentials class.
                cm.forgetPassword(cred.hostname, cred.database, cred.username);
              end
            end
          end

          % cache password (only reached once the password was accepted)
          if cache
            utils.helper.msg(msg.PROC1, 'cache password');
            cm.cacheCredentials(cred);
          end

        case 4
          try
            conn = cm.openConnection(varargin{1}, varargin{2}, varargin{3}, varargin{4});
          catch ex
            % look for access denied errors
            if ~strcmp(ex.identifier, 'utils:jmysql:connect:AccessDenied')
              rethrow(ex);
            end
            % ask for new credentials; the recursive call adds the
            % connection to the pool itself
            utils.helper.msg(msg.PROC1, ex.message);
            conn = cm.getConnection(varargin{1}, varargin{2}, varargin{3});
          end

        otherwise
          error('### invalid call')
      end

    end


    function conn = openConnection(cm, hostname, database, username, password)
      % Connect once with the given credentials, without any retry.
      % On success, cache the credentials without password and add the
      % connection to the pool. Access denied errors are rethrown unchanged.
      % Any other error is wrapped in a '### connection error' exception.
      import utils.const.*

      try
        % connect
        conn = connect(hostname, database, username, password);

        % cache credentials without password
        cred = credentials(hostname, database, username, []);
        cm.cacheCredentials(cred);

      catch ex
        if strcmp(ex.identifier, 'utils:jmysql:connect:AccessDenied')
          rethrow(ex);
        end
        % error out
        err = MException('', '### connection error');
        throw(err.addCause(ex));
      end

      % add connection to pool
      utils.helper.msg(msg.PROC1, 'add connection to pool');
      cm.connections{end+1} = conn;
    end


    function forgetPassword(cm, hostname, database, username)
      % Clear the password and expiry of the cache entries matching the
      % arguments, in the same way findCredentialsId invalidates expired
      % entries.
      % NOTE(review): which entries match an empty username depends on
      % credentials/match; confirm.
      import utils.const.*
      ids = findCredentialsId(cm, hostname, database, username);
      for kk = ids
        utils.helper.msg(msg.PROC1, 'forget password for cache entry id=%d', kk);
        cm.credentials{kk}.password = [];
        cm.credentials{kk}.expiry = 0;
      end
    end


    function ids = findCredentialsId(cm, varargin)
      % Return the indices of the cache entries for which
      % match(entry, varargin{:}) is true.
      % Side effect: entries for which expired() is true get their password
      % and expiry cleared.
      import utils.const.*
      ids = [];

      for kk = 1:numel(cm.credentials)
        % invalidate expired passwords
        if expired(cm.credentials{kk})
          utils.helper.msg(msg.PROC1, 'cache entry id=%d expired', kk);
          cm.credentials{kk}.password = [];
          cm.credentials{kk}.expiry = 0;
        end

        % match input with cache
        if match(cm.credentials{kk}, varargin{:})
          ids = [ ids kk ];
        end
      end
    end


    function cred = findCredentials(cm, varargin)
      % Return an array of the matching cached credentials, or [] if none.
      % default
      cred = [];

      % search
      ids = findCredentialsId(cm, varargin{:});

      % return an array of credentials
      if ~isempty(ids)
        cred = [ cm.credentials{ids} ];
      end
    end


    function cacheCredentials(cm, c)
      % Add credentials c to the cache, or update the matching entry.
      % When c carries a password, its expiry is set to now plus
      % credentialsExpiry. A matching entry that is already complete keeps
      % its contents; only its expiry is overwritten.
      import utils.const.*

      % find entry to update
      id = findCredentialsId(cm, c.hostname, c.database, c.username);

      % sanity check
      if numel(id) > 1
        error('### more than one cache entry for %s', char(c, 'short'));
      end

      % set password expiry time
      if ischar(c.password)
        c.expiry = double(time()) + cm.credentialsExpiry;
      end

      if isempty(id)
        % add at the end
        utils.helper.msg(msg.PROC1, 'add cache entry %s', char(c));
        cm.credentials{end+1} = c;
      else
        % update only if the cached information is less complete than ours
        if ~complete(cm.credentials{id})
          utils.helper.msg(msg.PROC1, 'update cache entry id=%d %s', id, char(c));
          cm.credentials{id} = c;
        else
          % always update expiry time
          % NOTE(review): c.expiry is only refreshed above when c has a
          % password; otherwise this copies c's default expiry. Confirm intent.
          cm.credentials{id}.expiry = c.expiry;
        end
      end
    end


    function [username, password, cache] = inputCredentials(cm, cred)
      % Interactive stub: ask for username and password.
      % The default username is the one in cred with the largest expiry
      % (most recently cached). cache tells whether the password should be
      % cached, according to cachePassword (asked when it is 2).

      % build a cell array of usernames
      users = { cred(:).username };

      % default to the latest used username: expiry grows with the time the
      % password was cached, so take the largest one
      [sorted, order] = sort([ cred(:).expiry ], 'descend'); %#ok<ASGLU>
      default = users{order(1)};

      username = choose('Username', users, default);

      password = ask('Password', '');

      if cm.cachePassword == 2
        cache    = ask('Store credentials', 'n');
        if ~isempty(cache) && cache(1) == 'y'
          cache = true;
        else
          cache = false;
        end
      else
        cache = logical(cm.cachePassword);
      end
    end


    function [hostname, database, username] = selectDatabase(cm)
      % Interactive stub: list the cached credentials and let the user pick
      % one, or enter a new hostname and database (the default).
      % The username is [] for a new entry.

      ncached = numel(cm.credentials);
      for kk = 1:ncached
        fprintf('% 2d.  %s\n', kk, char(cm.credentials{kk}));
      end
      fprintf('%d. NEW (default)\n', ncached+1);
      str = input('Select connection:  ', 's');
      if isempty(str)
        id = ncached+1;
      else
        % parse the reply as a number instead of evaluating user input
        id = str2double(str);
        if isnan(id) || id < 1 || id ~= fix(id)
          error('### invalid selection');
        end
      end
      if id > ncached
        hostname = input('Hostname:  ', 's');
        database = input('Database:  ', 's');
        username = [];
      else
        hostname = cm.credentials{id}.hostname;
        database = cm.credentials{id}.database;
        username = cm.credentials{id}.username;
      end
    end

  end % private methods

end % classdef


% this should become utils.jmysql.connect
function conn = connect(hostname, database, username, password)
  % CONNECT  Open a JDBC connection to a MySQL database.
  %
  %   conn = connect(hostname, database, username, password)
  %
  % Connects to jdbc:mysql://<hostname>/<database> through
  % com.mysql.jdbc.Driver. username and password are passed as driver
  % connection properties.
  %
  % Errors:
  % - A Java "java.sql.SQLException: Access denied" error is rethrown as an
  %   MException with identifier 'utils:jmysql:connect:AccessDenied',
  %   carrying the original error as its cause.
  % - Any other error is rethrown unchanged.

  import utils.const.*
  utils.helper.msg(msg.PROC1, 'connection to mysql://%s/%s username=%s', hostname, database, username);

  % driver and connection properties
  driver = javaObject('com.mysql.jdbc.Driver');
  url = sprintf('jdbc:mysql://%s/%s', hostname, database);
  props = javaObject('java.util.Properties');
  props.setProperty(driver.USER_PROPERTY_KEY, username);
  props.setProperty(driver.PASSWORD_PROPERTY_KEY, password);

  try
    conn = driver.connect(url, props);
  catch err
    % Java exceptions all share one MATLAB identifier, so the SQL error
    % has to be recognised from the message text.
    isJavaError = strcmp(err.identifier, 'MATLAB:Java:GenericException');
    if isJavaError && ~isempty(strfind(err.message, 'java.sql.SQLException: Access denied'))
      denied = MException('utils:jmysql:connect:AccessDenied', '### access denied');
      throw(denied.addCause(err));
    end
    rethrow(err);
  end
end


function str = ask(msg, default)
  % ASK  Prompt for a line of text, with a fallback value.
  %
  %   str = ask(msg, default)
  %
  % Shows "<msg> (default: <default>):  " and reads the reply as text.
  % An empty reply yields default. A non-char result is converted with char().
  prompt = sprintf('%s (default: %s):  ', msg, default);
  reply = input(prompt, 's');
  if ~isempty(reply)
    str = reply;
  else
    str = default;
  end
  if ~ischar(str)
    str = char(str);
  end
end

function str = choose(msg, choices, default)
  % CHOOSE  Prompt for a value, listing the given options.
  %
  %   str = choose(msg, choices, default)
  %
  % choices is a cell array whose entries are joined with ', ' in the
  % prompt. An empty reply yields default. The reply is not checked against
  % choices. A non-char result is converted with char().
  listing = sprintf('%s, ', choices{:});
  listing = listing(1:end-2);   % drop the trailing ', '
  reply = input(sprintf('%s (options: %s, default: %s):  ', msg, listing, default), 's');
  if isempty(reply)
    reply = default;
  end
  if ischar(reply)
    str = reply;
  else
    str = char(reply);
  end
end