diff options
Diffstat (limited to 'opendc/util')
| -rw-r--r-- | opendc/util/database.py | 17 | ||||
| -rw-r--r-- | opendc/util/exceptions.py | 11 | ||||
| -rw-r--r-- | opendc/util/parameter_checker.py | 21 | ||||
| -rw-r--r-- | opendc/util/path_parser.py | 4 | ||||
| -rw-r--r-- | opendc/util/rest.py | 24 |
5 files changed, 48 insertions, 29 deletions
diff --git a/opendc/util/database.py b/opendc/util/database.py index 32aa947c..e4c257d5 100644 --- a/opendc/util/database.py +++ b/opendc/util/database.py @@ -1,7 +1,6 @@ -from datetime import datetime import json -import sqlite3 import sys +from datetime import datetime from mysql.connector.pooling import MySQLConnectionPool @@ -12,10 +11,12 @@ with open(sys.argv[1]) as file: DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S' CONNECTION_POOL = None + def init_connection_pool(user, password, database, host, port): global CONNECTION_POOL - CONNECTION_POOL = MySQLConnectionPool(pool_name = "opendcpool", pool_size = 5, - user=user, password=password, database=database, host=host, port=port) + CONNECTION_POOL = MySQLConnectionPool(pool_name="opendcpool", pool_size=5, + user=user, password=password, database=database, host=host, port=port) + def execute(statement, t): """Open a database connection and execute the statement.""" @@ -23,7 +24,7 @@ def execute(statement, t): # Connect to the database connection = CONNECTION_POOL.get_connection() cursor = connection.cursor() - + # Execute the statement cursor.execute(statement, t) @@ -38,6 +39,7 @@ def execute(statement, t): # Return the id return row_id + def fetchone(statement, t=None): """Open a database connection and return the first row matched by the SELECT statement.""" @@ -58,6 +60,7 @@ def fetchone(statement, t=None): connection.close() return value + def fetchall(statement, t=None): """Open a database connection and return all rows matched by the SELECT statement.""" @@ -66,7 +69,7 @@ def fetchall(statement, t=None): cursor = connection.cursor() # Execute the SELECT statement - + if t is not None: cursor.execute(statement, t) else: @@ -78,11 +81,13 @@ def fetchall(statement, t=None): connection.close() return values + def datetime_to_string(datetime_to_convert): """Return a database-compatible string representation of the given datetime object.""" return datetime_to_convert.strftime(DATETIME_STRING_FORMAT) + def string_to_datetime(string_to_convert): """Return a datetime corresponding to the given string representation.""" diff --git a/opendc/util/exceptions.py b/opendc/util/exceptions.py index 56a04ab9..8eea268a 100644 --- a/opendc/util/exceptions.py +++ b/opendc/util/exceptions.py @@ -1,24 +1,30 @@ class RequestInitializationError(Exception): """Raised when a Request cannot successfully be initialized""" + class UnimplementedEndpointError(RequestInitializationError): """Raised when a Request path does not point to a module.""" + class MissingRequestParameterError(RequestInitializationError): """Raised when a Request does not contain one or more required parameters.""" + class UnsupportedMethodError(RequestInitializationError): """Raised when a Request does not use a supported REST method. The method must be in all-caps, supported by REST, and implemented by the module. """ + class AuthorizationTokenError(RequestInitializationError): """Raised when an authorization token is not correctly verified.""" + class ForeignKeyError(Exception): """Raised when a foreign key constraint is not met.""" + class RowNotFoundError(Exception): """Raised when a database row is not found.""" @@ -29,8 +35,10 @@ class RowNotFoundError(Exception): self.table_name = table_name + class ParameterError(Exception): - """Raised when a paramter is either missing or incorrectly typed.""" + """Raised when a parameter is either missing or incorrectly typed.""" + class IncorrectParameterError(ParameterError): """Raised when a parameter is of the wrong type.""" @@ -46,6 +54,7 @@ class IncorrectParameterError(ParameterError): self.parameter_name = parameter_name self.parameter_location = parameter_location + class MissingParameterError(ParameterError): """Raised when a parameter is missing.""" diff --git a/opendc/util/parameter_checker.py b/opendc/util/parameter_checker.py index 32cd6777..5188e56a 100644 --- a/opendc/util/parameter_checker.py +++ b/opendc/util/parameter_checker.py @@ -1,21 +1,22 @@ from opendc.util import database, exceptions + def _missing_parameter(params_required, params_actual, parent=''): """Recursively search for the first missing parameter.""" for param_name in params_required: - + if not param_name in params_actual: return '{}.{}'.format(parent, param_name) param_required = params_required.get(param_name) - param_actual = params_actual.get(param_name) + param_actual = params_actual.get(param_name) if isinstance(param_required, dict): - + param_missing = _missing_parameter( - param_required, - param_actual, + param_required, + param_actual, param_name ) @@ -24,13 +25,14 @@ def _missing_parameter(params_required, params_actual, parent=''): return None + def _incorrect_parameter(params_required, params_actual, parent=''): """Recursively make sure each parameter is of the correct type.""" for param_name in params_required: param_required = params_required.get(param_name) - param_actual = params_actual.get(param_name) + param_actual = params_actual.get(param_name) if isinstance(param_required, dict): @@ -60,6 +62,7 @@ def _incorrect_parameter(params_required, params_actual, parent=''): if param_required.startswith('list') and not isinstance(param_actual, list): return '{}.{}'.format(parent, param_name) + def _format_parameter(parameter): """Format the output of a parameter check.""" @@ -67,11 +70,12 @@ def _format_parameter(parameter): inner = ['["{}"]'.format(x) for x in parts[2:]] return parts[1] + ''.join(inner) + def check(request, **kwargs): """Return True if all required parameters are there.""" for location, params_required in kwargs.iteritems(): - + params_actual = getattr(request, 'params_{}'.format(location)) missing_parameter = _missing_parameter(params_required, params_actual) @@ -79,7 +83,7 @@ def check(request, **kwargs): raise exceptions.MissingParameterError( _format_parameter(missing_parameter), location - ) + ) incorrect_parameter = _incorrect_parameter(params_required, params_actual) if incorrect_parameter is not None: @@ -87,4 +91,3 @@ def check(request, **kwargs): _format_parameter(incorrect_parameter), location ) - diff --git a/opendc/util/path_parser.py b/opendc/util/path_parser.py index 292b747b..7948ee1b 100644 --- a/opendc/util/path_parser.py +++ b/opendc/util/path_parser.py @@ -1,6 +1,6 @@ import json -import sys, os -import re +import os + def parse(version, endpoint_path): """Map an HTTP endpoint path to an API path""" diff --git a/opendc/util/rest.py b/opendc/util/rest.py index ad53f084..7cf2d0b3 100644 --- a/opendc/util/rest.py +++ b/opendc/util/rest.py @@ -1,6 +1,5 @@ import importlib import json -import os import sys from oauth2client import client, crypt @@ -10,6 +9,7 @@ from opendc.util import exceptions, parameter_checker with open(sys.argv[1]) as file: KEYS = json.load(file) + class Request(object): """WebSocket message to REST request mapping.""" @@ -23,16 +23,16 @@ class Request(object): try: self.message = message - + self.id = message['id'] - + self.path = message['path'] self.method = message['method'] - + self.params_body = message['parameters']['body'] self.params_path = message['parameters']['path'] self.params_query = message['parameters']['query'] - + self.token = message['token'] except KeyError as exception: @@ -40,9 +40,9 @@ class Request(object): # Parse the path and import the appropriate module - try: + try: self.path = message['path'].encode('ascii', 'ignore').strip('/') - + module_base = 'opendc.api.{}.endpoint' module_path = self.path.translate(None, '{}').replace('/', '.') @@ -62,10 +62,11 @@ class Request(object): raise exceptions.UnsupportedMethodError('Non-rest method: {}'.format(self.method)) if not hasattr(self.module, self.method): - raise exceptions.UnsupportedMethodError('Unimplemented method at endpoint {}: {}'.format(self.path, self.method)) + raise exceptions.UnsupportedMethodError( + 'Unimplemented method at endpoint {}: {}'.format(self.path, self.method)) # Verify the user - + try: self.google_id = self._verify_token(self.token) @@ -87,7 +88,7 @@ class Request(object): raise crypt.AppIdentityError('Unrecognized client.') if idinfo['iss'] not in ['accounts.google.com', 'https://accounts.google.com']: - raise crypt.AppIdentityError('Wrong issuer.') + raise crypt.AppIdentityError('Wrong issuer.') return idinfo['sub'] @@ -114,6 +115,7 @@ class Request(object): return json.dumps(self.message) + class Response(object): """Response to websocket mapping""" @@ -125,7 +127,7 @@ class Response(object): 'description': status_description } self.content = content - + def to_JSON(self): """"Return a JSON representation of this Response""" |
