summaryrefslogtreecommitdiff
path: root/opendc/util
diff options
context:
space:
mode:
authorGeorgios Andreadis <g.andreadis@student.tudelft.nl>2017-09-25 13:50:49 +0200
committerGeorgios Andreadis <g.andreadis@student.tudelft.nl>2017-09-25 13:50:49 +0200
commita1589e75358558eada7ffc2efc7e3fa7160d233e (patch)
tree7889a2364292cd8b90fe996da7907bebf200d3dc /opendc/util
parent1f34466d41ba01a3dd36b0866696367d397daf7e (diff)
Reformat codebase and fix spelling errors
Diffstat (limited to 'opendc/util')
-rw-r--r--opendc/util/database.py17
-rw-r--r--opendc/util/exceptions.py11
-rw-r--r--opendc/util/parameter_checker.py21
-rw-r--r--opendc/util/path_parser.py4
-rw-r--r--opendc/util/rest.py24
5 files changed, 48 insertions, 29 deletions
diff --git a/opendc/util/database.py b/opendc/util/database.py
index 32aa947c..e4c257d5 100644
--- a/opendc/util/database.py
+++ b/opendc/util/database.py
@@ -1,7 +1,6 @@
-from datetime import datetime
import json
-import sqlite3
import sys
+from datetime import datetime
from mysql.connector.pooling import MySQLConnectionPool
@@ -12,10 +11,12 @@ with open(sys.argv[1]) as file:
DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S'
CONNECTION_POOL = None
+
def init_connection_pool(user, password, database, host, port):
global CONNECTION_POOL
- CONNECTION_POOL = MySQLConnectionPool(pool_name = "opendcpool", pool_size = 5,
- user=user, password=password, database=database, host=host, port=port)
+ CONNECTION_POOL = MySQLConnectionPool(pool_name="opendcpool", pool_size=5,
+ user=user, password=password, database=database, host=host, port=port)
+
def execute(statement, t):
"""Open a database connection and execute the statement."""
@@ -23,7 +24,7 @@ def execute(statement, t):
# Connect to the database
connection = CONNECTION_POOL.get_connection()
cursor = connection.cursor()
-
+
# Execute the statement
cursor.execute(statement, t)
@@ -38,6 +39,7 @@ def execute(statement, t):
# Return the id
return row_id
+
def fetchone(statement, t=None):
"""Open a database connection and return the first row matched by the SELECT statement."""
@@ -58,6 +60,7 @@ def fetchone(statement, t=None):
connection.close()
return value
+
def fetchall(statement, t=None):
"""Open a database connection and return all rows matched by the SELECT statement."""
@@ -66,7 +69,7 @@ def fetchall(statement, t=None):
cursor = connection.cursor()
# Execute the SELECT statement
-
+
if t is not None:
cursor.execute(statement, t)
else:
@@ -78,11 +81,13 @@ def fetchall(statement, t=None):
connection.close()
return values
+
def datetime_to_string(datetime_to_convert):
"""Return a database-compatible string representation of the given datetime object."""
return datetime_to_convert.strftime(DATETIME_STRING_FORMAT)
+
def string_to_datetime(string_to_convert):
"""Return a datetime corresponding to the given string representation."""
diff --git a/opendc/util/exceptions.py b/opendc/util/exceptions.py
index 56a04ab9..8eea268a 100644
--- a/opendc/util/exceptions.py
+++ b/opendc/util/exceptions.py
@@ -1,24 +1,30 @@
class RequestInitializationError(Exception):
"""Raised when a Request cannot successfully be initialized"""
+
class UnimplementedEndpointError(RequestInitializationError):
"""Raised when a Request path does not point to a module."""
+
class MissingRequestParameterError(RequestInitializationError):
"""Raised when a Request does not contain one or more required parameters."""
+
class UnsupportedMethodError(RequestInitializationError):
"""Raised when a Request does not use a supported REST method.
The method must be in all-caps, supported by REST, and implemented by the module.
"""
+
class AuthorizationTokenError(RequestInitializationError):
"""Raised when an authorization token is not correctly verified."""
+
class ForeignKeyError(Exception):
"""Raised when a foreign key constraint is not met."""
+
class RowNotFoundError(Exception):
"""Raised when a database row is not found."""
@@ -29,8 +35,10 @@ class RowNotFoundError(Exception):
self.table_name = table_name
+
class ParameterError(Exception):
- """Raised when a paramter is either missing or incorrectly typed."""
+ """Raised when a parameter is either missing or incorrectly typed."""
+
class IncorrectParameterError(ParameterError):
"""Raised when a parameter is of the wrong type."""
@@ -46,6 +54,7 @@ class IncorrectParameterError(ParameterError):
self.parameter_name = parameter_name
self.parameter_location = parameter_location
+
class MissingParameterError(ParameterError):
"""Raised when a parameter is missing."""
diff --git a/opendc/util/parameter_checker.py b/opendc/util/parameter_checker.py
index 32cd6777..5188e56a 100644
--- a/opendc/util/parameter_checker.py
+++ b/opendc/util/parameter_checker.py
@@ -1,21 +1,22 @@
from opendc.util import database, exceptions
+
def _missing_parameter(params_required, params_actual, parent=''):
"""Recursively search for the first missing parameter."""
for param_name in params_required:
-
+
if not param_name in params_actual:
return '{}.{}'.format(parent, param_name)
param_required = params_required.get(param_name)
- param_actual = params_actual.get(param_name)
+ param_actual = params_actual.get(param_name)
if isinstance(param_required, dict):
-
+
param_missing = _missing_parameter(
- param_required,
- param_actual,
+ param_required,
+ param_actual,
param_name
)
@@ -24,13 +25,14 @@ def _missing_parameter(params_required, params_actual, parent=''):
return None
+
def _incorrect_parameter(params_required, params_actual, parent=''):
"""Recursively make sure each parameter is of the correct type."""
for param_name in params_required:
param_required = params_required.get(param_name)
- param_actual = params_actual.get(param_name)
+ param_actual = params_actual.get(param_name)
if isinstance(param_required, dict):
@@ -60,6 +62,7 @@ def _incorrect_parameter(params_required, params_actual, parent=''):
if param_required.startswith('list') and not isinstance(param_actual, list):
return '{}.{}'.format(parent, param_name)
+
def _format_parameter(parameter):
"""Format the output of a parameter check."""
@@ -67,11 +70,12 @@ def _format_parameter(parameter):
inner = ['["{}"]'.format(x) for x in parts[2:]]
return parts[1] + ''.join(inner)
+
def check(request, **kwargs):
"""Return True if all required parameters are there."""
for location, params_required in kwargs.iteritems():
-
+
params_actual = getattr(request, 'params_{}'.format(location))
missing_parameter = _missing_parameter(params_required, params_actual)
@@ -79,7 +83,7 @@ def check(request, **kwargs):
raise exceptions.MissingParameterError(
_format_parameter(missing_parameter),
location
- )
+ )
incorrect_parameter = _incorrect_parameter(params_required, params_actual)
if incorrect_parameter is not None:
@@ -87,4 +91,3 @@ def check(request, **kwargs):
_format_parameter(incorrect_parameter),
location
)
-
diff --git a/opendc/util/path_parser.py b/opendc/util/path_parser.py
index 292b747b..7948ee1b 100644
--- a/opendc/util/path_parser.py
+++ b/opendc/util/path_parser.py
@@ -1,6 +1,6 @@
import json
-import sys, os
-import re
+import os
+
def parse(version, endpoint_path):
"""Map an HTTP endpoint path to an API path"""
diff --git a/opendc/util/rest.py b/opendc/util/rest.py
index ad53f084..7cf2d0b3 100644
--- a/opendc/util/rest.py
+++ b/opendc/util/rest.py
@@ -1,6 +1,5 @@
import importlib
import json
-import os
import sys
from oauth2client import client, crypt
@@ -10,6 +9,7 @@ from opendc.util import exceptions, parameter_checker
with open(sys.argv[1]) as file:
KEYS = json.load(file)
+
class Request(object):
"""WebSocket message to REST request mapping."""
@@ -23,16 +23,16 @@ class Request(object):
try:
self.message = message
-
+
self.id = message['id']
-
+
self.path = message['path']
self.method = message['method']
-
+
self.params_body = message['parameters']['body']
self.params_path = message['parameters']['path']
self.params_query = message['parameters']['query']
-
+
self.token = message['token']
except KeyError as exception:
@@ -40,9 +40,9 @@ class Request(object):
# Parse the path and import the appropriate module
- try:
+ try:
self.path = message['path'].encode('ascii', 'ignore').strip('/')
-
+
module_base = 'opendc.api.{}.endpoint'
module_path = self.path.translate(None, '{}').replace('/', '.')
@@ -62,10 +62,11 @@ class Request(object):
raise exceptions.UnsupportedMethodError('Non-rest method: {}'.format(self.method))
if not hasattr(self.module, self.method):
- raise exceptions.UnsupportedMethodError('Unimplemented method at endpoint {}: {}'.format(self.path, self.method))
+ raise exceptions.UnsupportedMethodError(
+ 'Unimplemented method at endpoint {}: {}'.format(self.path, self.method))
# Verify the user
-
+
try:
self.google_id = self._verify_token(self.token)
@@ -87,7 +88,7 @@ class Request(object):
raise crypt.AppIdentityError('Unrecognized client.')
if idinfo['iss'] not in ['accounts.google.com', 'https://accounts.google.com']:
- raise crypt.AppIdentityError('Wrong issuer.')
+ raise crypt.AppIdentityError('Wrong issuer.')
return idinfo['sub']
@@ -114,6 +115,7 @@ class Request(object):
return json.dumps(self.message)
+
class Response(object):
"""Response to websocket mapping"""
@@ -125,7 +127,7 @@ class Response(object):
'description': status_description
}
self.content = content
-
+
def to_JSON(self):
""""Return a JSON representation of this Response"""