From 2281d3265423d01e60f8cc088de5a5730bb8a910 Mon Sep 17 00:00:00 2001 From: Fabian Mastenbroek Date: Sat, 15 May 2021 13:09:06 +0200 Subject: api: Migrate to Flask Restful This change updates the API to use Flask Restful instead of our own in-house REST library. This change reduces the maintenance effort and allows us to drastically simplify the API implementation needed for the OpenDC v2 API. --- opendc-web/opendc-web-api/opendc/util/__init__.py | 0 opendc-web/opendc-web-api/opendc/util/auth.py | 253 --------------------- opendc-web/opendc-web-api/opendc/util/database.py | 77 ------- .../opendc-web-api/opendc/util/exceptions.py | 64 ------ opendc-web/opendc-web-api/opendc/util/json.py | 12 - .../opendc/util/parameter_checker.py | 85 ------- .../opendc-web-api/opendc/util/path_parser.py | 36 --- opendc-web/opendc-web-api/opendc/util/rest.py | 109 --------- 8 files changed, 636 deletions(-) delete mode 100644 opendc-web/opendc-web-api/opendc/util/__init__.py delete mode 100644 opendc-web/opendc-web-api/opendc/util/auth.py delete mode 100644 opendc-web/opendc-web-api/opendc/util/database.py delete mode 100644 opendc-web/opendc-web-api/opendc/util/exceptions.py delete mode 100644 opendc-web/opendc-web-api/opendc/util/json.py delete mode 100644 opendc-web/opendc-web-api/opendc/util/parameter_checker.py delete mode 100644 opendc-web/opendc-web-api/opendc/util/path_parser.py delete mode 100644 opendc-web/opendc-web-api/opendc/util/rest.py (limited to 'opendc-web/opendc-web-api/opendc/util') diff --git a/opendc-web/opendc-web-api/opendc/util/__init__.py b/opendc-web/opendc-web-api/opendc/util/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/opendc-web/opendc-web-api/opendc/util/auth.py b/opendc-web/opendc-web-api/opendc/util/auth.py deleted file mode 100644 index 810b582a..00000000 --- a/opendc-web/opendc-web-api/opendc/util/auth.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright (c) 2021 AtLarge Research -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -import json -import time -from functools import wraps - -import urllib3 -from flask import request, _request_ctx_stack -from jose import jwt, JWTError -from werkzeug.local import LocalProxy - -current_user = LocalProxy(lambda: getattr(_request_ctx_stack.top, 'current_user', None)) - - -class AuthError(Exception): - """ - This error is thrown when the request failed to authorize. - """ - - def __init__(self, error, status_code): - Exception.__init__(self, error) - self.error = error - self.status_code = status_code - - -class AuthManager: - """ - This class handles the authorization of requests. - """ - - def __init__(self, alg, issuer, audience): - self._alg = alg - self._issuer = issuer - self._audience = audience - - def require(self, f): - """Determines if the Access Token is valid - """ - - @wraps(f) - def decorated(*args, **kwargs): - token = _get_token() - try: - header = jwt.get_unverified_header(token) - except JWTError as e: - raise AuthError({"code": "invalid_token", - "description": str(e)}, 401) - - alg = header.get('alg', None) - if alg != self._alg.algorithm: - raise AuthError({"code": "invalid_header", - "description": f"Signature algorithm of {alg} is not supported. Expected the ID token " - f"to be signed with {self._alg.algorithm}"}, 401) - - kid = header.get('kid', None) - try: - secret_or_certificate = self._alg.get_key(key_id=kid) - except TokenValidationError as e: - raise AuthError({"code": "invalid_header", - "description": str(e)}, 401) - try: - payload = jwt.decode(token, - key=secret_or_certificate, - algorithms=[self._alg.algorithm], - audience=self._audience, - issuer=self._issuer) - _request_ctx_stack.top.current_user = payload - return f(*args, **kwargs) - except jwt.ExpiredSignatureError: - raise AuthError({"code": "token_expired", - "description": "token is expired"}, 401) - except jwt.JWTClaimsError: - raise AuthError({"code": "invalid_claims", - "description": - "incorrect claims," - "please check the audience and issuer"}, 401) - except Exception as e: - print(e) - raise AuthError({"code": "invalid_header", - "description": - "Unable to parse authentication" - " token."}, 401) - - return decorated - - -def _get_token(): - """ - Obtain the Access Token from the Authorization Header - """ - auth = request.headers.get("Authorization", None) - if not auth: - raise AuthError({"code": "authorization_header_missing", - "description": - "Authorization header is expected"}, 401) - - parts = auth.split() - - if parts[0].lower() != "bearer": - raise AuthError({"code": "invalid_header", - "description": - "Authorization header must start with" - " Bearer"}, 401) - if len(parts) == 1: - raise AuthError({"code": "invalid_header", - "description": "Token not found"}, 401) - if len(parts) > 2: - raise AuthError({"code": "invalid_header", - "description": - "Authorization header must be" - " Bearer token"}, 401) - - token = parts[1] - return token - - -class SymmetricJwtAlgorithm: - """Verifier for HMAC signatures, which rely on shared secrets. - Args: - shared_secret (str): The shared secret used to decode the token. - algorithm (str, optional): The expected signing algorithm. Defaults to "HS256". - """ - - def __init__(self, shared_secret, algorithm="HS256"): - self.algorithm = algorithm - self._shared_secret = shared_secret - - # pylint: disable=W0613 - def get_key(self, key_id=None): - """ - Obtain the key for this algorithm. - :param key_id: The identifier of the key. - :return: The JWK key. - """ - return self._shared_secret - - -class AsymmetricJwtAlgorithm: - """Verifier for RSA signatures, which rely on public key certificates. - Args: - jwks_url (str): The url where the JWK set is located. - algorithm (str, optional): The expected signing algorithm. Defaults to "RS256". - """ - - def __init__(self, jwks_url, algorithm="RS256"): - self.algorithm = algorithm - self._fetcher = JwksFetcher(jwks_url) - - def get_key(self, key_id=None): - """ - Obtain the key for this algorithm. - :param key_id: The identifier of the key. - :return: The JWK key. - """ - return self._fetcher.get_key(key_id) - - -class TokenValidationError(Exception): - """ - Error thrown when the token cannot be validated - """ - - -class JwksFetcher: - """Class that fetches and holds a JSON web key set. - This class makes use of an in-memory cache. For it to work properly, define this instance once and re-use it. - Args: - jwks_url (str): The url where the JWK set is located. - cache_ttl (str, optional): The lifetime of the JWK set cache in seconds. Defaults to 600 seconds. - """ - CACHE_TTL = 600 # 10 min cache lifetime - - def __init__(self, jwks_url, cache_ttl=CACHE_TTL): - self._jwks_url = jwks_url - self._http = urllib3.PoolManager() - self._cache_value = {} - self._cache_date = 0 - self._cache_ttl = cache_ttl - self._cache_is_fresh = False - - def _fetch_jwks(self, force=False): - """Attempts to obtain the JWK set from the cache, as long as it's still valid. - When not, it will perform a network request to the jwks_url to obtain a fresh result - and update the cache value with it. - Args: - force (bool, optional): whether to ignore the cache and force a network request or not. Defaults to False. - """ - has_expired = self._cache_date + self._cache_ttl < time.time() - - if not force and not has_expired: - # Return from cache - self._cache_is_fresh = False - return self._cache_value - - # Invalidate cache and fetch fresh data - self._cache_value = {} - response = self._http.request('GET', self._jwks_url) - - if response.status == 200: - # Update cache - jwks = json.loads(response.data.decode('utf-8')) - self._cache_value = self._parse_jwks(jwks) - self._cache_is_fresh = True - self._cache_date = time.time() - return self._cache_value - - @staticmethod - def _parse_jwks(jwks): - """Converts a JWK string representation into a binary certificate in PEM format. - """ - keys = {} - - for key in jwks['keys']: - keys[key["kid"]] = key - return keys - - def get_key(self, key_id): - """Obtains the JWK associated with the given key id. - Args: - key_id (str): The id of the key to fetch. - Returns: - the JWK associated with the given key id. - - Raises: - TokenValidationError: when a key with that id cannot be found - """ - keys = self._fetch_jwks() - - if keys and key_id in keys: - return keys[key_id] - - if not self._cache_is_fresh: - keys = self._fetch_jwks(force=True) - if keys and key_id in keys: - return keys[key_id] - raise TokenValidationError(f"RSA Public Key with ID {key_id} was not found.") diff --git a/opendc-web/opendc-web-api/opendc/util/database.py b/opendc-web/opendc-web-api/opendc/util/database.py deleted file mode 100644 index dd26533d..00000000 --- a/opendc-web/opendc-web-api/opendc/util/database.py +++ /dev/null @@ -1,77 +0,0 @@ -import urllib.parse -from datetime import datetime - -from pymongo import MongoClient - -DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S' -CONNECTION_POOL = None - - -class Database: - """Object holding functionality for database access.""" - def __init__(self): - self.opendc_db = None - - def initialize_database(self, user, password, database, host): - """Initializes the database connection.""" - - user = urllib.parse.quote_plus(user) - password = urllib.parse.quote_plus(password) - database = urllib.parse.quote_plus(database) - host = urllib.parse.quote_plus(host) - - client = MongoClient('mongodb://%s:%s@%s/default_db?authSource=%s' % (user, password, host, database)) - self.opendc_db = client.opendc - - def fetch_one(self, query, collection): - """Uses existing mongo connection to return a single (the first) document in a collection matching the given - query as a JSON object. - - The query needs to be in json format, i.e.: `{'name': prefab_name}`. - """ - return getattr(self.opendc_db, collection).find_one(query) - - def fetch_all(self, query, collection): - """Uses existing mongo connection to return all documents matching a given query, as a list of JSON objects. - - The query needs to be in json format, i.e.: `{'name': prefab_name}`. - """ - cursor = getattr(self.opendc_db, collection).find(query) - return list(cursor) - - def insert(self, obj, collection): - """Updates an existing object.""" - bson = getattr(self.opendc_db, collection).insert(obj) - - return bson - - def update(self, _id, obj, collection): - """Updates an existing object.""" - return getattr(self.opendc_db, collection).update({'_id': _id}, obj) - - def delete_one(self, query, collection): - """Deletes one object matching the given query. - - The query needs to be in json format, i.e.: `{'name': prefab_name}`. - """ - getattr(self.opendc_db, collection).delete_one(query) - - def delete_all(self, query, collection): - """Deletes all objects matching the given query. - - The query needs to be in json format, i.e.: `{'name': prefab_name}`. - """ - getattr(self.opendc_db, collection).delete_many(query) - - @staticmethod - def datetime_to_string(datetime_to_convert): - """Return a database-compatible string representation of the given datetime object.""" - return datetime_to_convert.strftime(DATETIME_STRING_FORMAT) - - @staticmethod - def string_to_datetime(string_to_convert): - """Return a datetime corresponding to the given string representation.""" - return datetime.strptime(string_to_convert, DATETIME_STRING_FORMAT) - - -DB = Database() diff --git a/opendc-web/opendc-web-api/opendc/util/exceptions.py b/opendc-web/opendc-web-api/opendc/util/exceptions.py deleted file mode 100644 index 7724a407..00000000 --- a/opendc-web/opendc-web-api/opendc/util/exceptions.py +++ /dev/null @@ -1,64 +0,0 @@ -class RequestInitializationError(Exception): - """Raised when a Request cannot successfully be initialized""" - - -class UnimplementedEndpointError(RequestInitializationError): - """Raised when a Request path does not point to a module.""" - - -class MissingRequestParameterError(RequestInitializationError): - """Raised when a Request does not contain one or more required parameters.""" - - -class UnsupportedMethodError(RequestInitializationError): - """Raised when a Request does not use a supported REST method. - - The method must be in all-caps, supported by REST, and implemented by the module. - """ - - -class AuthorizationTokenError(RequestInitializationError): - """Raised when an authorization token is not correctly verified.""" - - -class ForeignKeyError(Exception): - """Raised when a foreign key constraint is not met.""" - - -class RowNotFoundError(Exception): - """Raised when a database row is not found.""" - def __init__(self, table_name): - super(RowNotFoundError, self).__init__('Row in `{}` table not found.'.format(table_name)) - - self.table_name = table_name - - -class ParameterError(Exception): - """Raised when a parameter is either missing or incorrectly typed.""" - - -class IncorrectParameterError(ParameterError): - """Raised when a parameter is of the wrong type.""" - def __init__(self, parameter_name, parameter_location): - super(IncorrectParameterError, - self).__init__('Incorrectly typed `{}` {} parameter.'.format(parameter_name, parameter_location)) - - self.parameter_name = parameter_name - self.parameter_location = parameter_location - - -class MissingParameterError(ParameterError): - """Raised when a parameter is missing.""" - def __init__(self, parameter_name, parameter_location): - super(MissingParameterError, - self).__init__('Missing required `{}` {} parameter.'.format(parameter_name, parameter_location)) - - self.parameter_name = parameter_name - self.parameter_location = parameter_location - - -class ClientError(Exception): - """Raised when a 4xx response is to be returned.""" - def __init__(self, response): - super(ClientError, self).__init__(str(response)) - self.response = response diff --git a/opendc-web/opendc-web-api/opendc/util/json.py b/opendc-web/opendc-web-api/opendc/util/json.py deleted file mode 100644 index 2ef4f965..00000000 --- a/opendc-web/opendc-web-api/opendc/util/json.py +++ /dev/null @@ -1,12 +0,0 @@ -import flask -from bson.objectid import ObjectId - - -class JSONEncoder(flask.json.JSONEncoder): - """ - A customized JSON encoder to handle unsupported types. - """ - def default(self, o): - if isinstance(o, ObjectId): - return str(o) - return flask.json.JSONEncoder.default(self, o) diff --git a/opendc-web/opendc-web-api/opendc/util/parameter_checker.py b/opendc-web/opendc-web-api/opendc/util/parameter_checker.py deleted file mode 100644 index 14dd1dc0..00000000 --- a/opendc-web/opendc-web-api/opendc/util/parameter_checker.py +++ /dev/null @@ -1,85 +0,0 @@ -from opendc.util import exceptions -from opendc.util.database import Database - - -def _missing_parameter(params_required, params_actual, parent=''): - """Recursively search for the first missing parameter.""" - - for param_name in params_required: - - if param_name not in params_actual: - return '{}.{}'.format(parent, param_name) - - param_required = params_required.get(param_name) - param_actual = params_actual.get(param_name) - - if isinstance(param_required, dict): - - param_missing = _missing_parameter(param_required, param_actual, param_name) - - if param_missing is not None: - return '{}.{}'.format(parent, param_missing) - - return None - - -def _incorrect_parameter(params_required, params_actual, parent=''): - """Recursively make sure each parameter is of the correct type.""" - - for param_name in params_required: - - param_required = params_required.get(param_name) - param_actual = params_actual.get(param_name) - - if isinstance(param_required, dict): - - param_incorrect = _incorrect_parameter(param_required, param_actual, param_name) - - if param_incorrect is not None: - return '{}.{}'.format(parent, param_incorrect) - - else: - - if param_required == 'datetime': - try: - Database.string_to_datetime(param_actual) - except: - return '{}.{}'.format(parent, param_name) - - type_pairs = [ - ('int', (int,)), - ('float', (float, int)), - ('bool', (bool,)), - ('string', (str, int)), - ('list', (list,)), - ] - - for str_type, actual_types in type_pairs: - if param_required == str_type and all(not isinstance(param_actual, t) - for t in actual_types): - return '{}.{}'.format(parent, param_name) - - return None - - -def _format_parameter(parameter): - """Format the output of a parameter check.""" - - parts = parameter.split('.') - inner = ['["{}"]'.format(x) for x in parts[2:]] - return parts[1] + ''.join(inner) - - -def check(request, **kwargs): - """Check if all required parameters are there.""" - - for location, params_required in kwargs.items(): - params_actual = getattr(request, 'params_{}'.format(location)) - - missing_parameter = _missing_parameter(params_required, params_actual) - if missing_parameter is not None: - raise exceptions.MissingParameterError(_format_parameter(missing_parameter), location) - - incorrect_parameter = _incorrect_parameter(params_required, params_actual) - if incorrect_parameter is not None: - raise exceptions.IncorrectParameterError(_format_parameter(incorrect_parameter), location) diff --git a/opendc-web/opendc-web-api/opendc/util/path_parser.py b/opendc-web/opendc-web-api/opendc/util/path_parser.py deleted file mode 100644 index c8452f20..00000000 --- a/opendc-web/opendc-web-api/opendc/util/path_parser.py +++ /dev/null @@ -1,36 +0,0 @@ -import json -import os - - -def parse(version, endpoint_path): - """Map an HTTP endpoint path to an API path""" - - # Get possible paths - with open(os.path.join(os.path.dirname(__file__), '..', 'api', '{}', 'paths.json').format(version)) as paths_file: - paths = json.load(paths_file) - - # Find API path that matches endpoint_path - endpoint_path_parts = endpoint_path.strip('/').split('/') - paths_parts = [x.strip('/').split('/') for x in paths if len(x.strip('/').split('/')) == len(endpoint_path_parts)] - path = None - - for path_parts in paths_parts: - found = True - for (endpoint_part, part) in zip(endpoint_path_parts, path_parts): - if not part.startswith('{') and endpoint_part != part: - found = False - break - if found: - path = path_parts - - if path is None: - return None - - # Extract path parameters - parameters = {} - - for (name, value) in zip(path, endpoint_path_parts): - if name.startswith('{'): - parameters[name.strip('{}')] = value - - return '{}/{}'.format(version, '/'.join(path)), parameters diff --git a/opendc-web/opendc-web-api/opendc/util/rest.py b/opendc-web/opendc-web-api/opendc/util/rest.py deleted file mode 100644 index 63d063b3..00000000 --- a/opendc-web/opendc-web-api/opendc/util/rest.py +++ /dev/null @@ -1,109 +0,0 @@ -import importlib -import json - -from opendc.util import exceptions, parameter_checker -from opendc.util.exceptions import ClientError -from opendc.util.auth import current_user - - -class Request: - """WebSocket message to REST request mapping.""" - def __init__(self, message=None): - """"Initialize a Request from a socket message.""" - - # Get the Request parameters from the message - - if message is None: - return - - try: - self.message = message - - self.id = message['id'] - - self.path = message['path'] - self.method = message['method'] - - self.params_body = message['parameters']['body'] - self.params_path = message['parameters']['path'] - self.params_query = message['parameters']['query'] - - self.token = message['token'] - - except KeyError as exception: - raise exceptions.MissingRequestParameterError(exception) - - # Parse the path and import the appropriate module - - try: - self.path = message['path'].strip('/') - - module_base = 'opendc.api.{}.endpoint' - module_path = self.path.replace('{', '').replace('}', '').replace('/', '.') - - self.module = importlib.import_module(module_base.format(module_path)) - except ImportError as e: - print(e) - raise exceptions.UnimplementedEndpointError('Unimplemented endpoint: {}.'.format(self.path)) - - # Check the method - - if self.method not in ['POST', 'GET', 'PUT', 'PATCH', 'DELETE']: - raise exceptions.UnsupportedMethodError('Non-rest method: {}'.format(self.method)) - - if not hasattr(self.module, self.method): - raise exceptions.UnsupportedMethodError('Unimplemented method at endpoint {}: {}'.format( - self.path, self.method)) - - self.current_user = current_user - - def check_required_parameters(self, **kwargs): - """Raise an error if a parameter is missing or of the wrong type.""" - - try: - parameter_checker.check(self, **kwargs) - except exceptions.ParameterError as e: - raise ClientError(Response(400, str(e))) - - def process(self): - """Process the Request and return a Response.""" - - method = getattr(self.module, self.method) - - try: - response = method(self) - except ClientError as e: - e.response.id = self.id - return e.response - - response.id = self.id - - return response - - def to_JSON(self): - """Return a JSON representation of this Request""" - - self.message['id'] = 0 - self.message['token'] = None - - return json.dumps(self.message) - - -class Response: - """Response to websocket mapping""" - def __init__(self, status_code, status_description, content=None): - """Initialize a new Response.""" - - self.id = 0 - self.status = {'code': status_code, 'description': status_description} - self.content = content - - def to_JSON(self): - """"Return a JSON representation of this Response""" - - data = {'id': self.id, 'status': self.status} - - if self.content is not None: - data['content'] = self.content - - return json.dumps(data, default=str) -- cgit v1.2.3