Commit f82bb35e authored by Carina Antunes's avatar Carina Antunes
Browse files

refactor, error handlers

parent 89065c67
......@@ -23,6 +23,7 @@ from invenio_records_rest.facets import terms_filter
from kombu import Exchange, Queue
from cern_search_rest_api.modules.cernsearch.api import CernSearchRecord
from cern_search_rest_api.modules.cernsearch.errors import BadRequestError
from cern_search_rest_api.modules.cernsearch.facets import regex_aggregation, simple_query_string
from cern_search_rest_api.modules.cernsearch.indexer import CernSearchRecordIndexer
from cern_search_rest_api.modules.cernsearch.permissions import (
......@@ -32,7 +33,10 @@ from cern_search_rest_api.modules.cernsearch.permissions import (
record_read_permission_factory,
record_update_permission_factory,
)
from cern_search_rest_api.modules.cernsearch.views import elasticsearch_version_conflict_engine_exception_handler
from cern_search_rest_api.modules.cernsearch.views import (
elasticsearch_version_conflict_engine_exception_handler,
generic_bad_exception_handler,
)
def _(x):
......@@ -170,6 +174,7 @@ RECORDS_REST_ENDPOINTS = dict(
},
error_handlers={
TransportError: elasticsearch_version_conflict_engine_exception_handler,
BadRequestError: generic_bad_exception_handler,
},
)
)
......
......@@ -7,6 +7,7 @@
# Citadel Search is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
"""Custom errors."""
from invenio_pidstore.errors import PIDValueError
from invenio_records_rest.errors import PIDRESTException
from invenio_rest.errors import RESTException, RESTValidationError
......@@ -66,5 +67,19 @@ class ConflictError(RESTException):
class PIDAlreadyExistsRESTError(PIDRESTException):
"""Persistent identifier already exists error."""
code = 404
code = 400
description = "PID already exists."
class BadRequestError(Exception):
"""Bad Request Error."""
class PIDSizeTooBigError(BadRequestError, PIDValueError):
"""PID size is too big."""
class BadRequestRESTError(RESTException):
"""Bad Request Error."""
code = 400
......@@ -13,7 +13,7 @@ from cern_search_rest_api.modules.cernsearch.celery import DeclareDeadletter
from cern_search_rest_api.modules.cernsearch.views import (
build_blueprint_record_files_content,
build_ubq_blueprint,
register_extra_routes,
create_csas_blueprint,
)
......@@ -32,7 +32,8 @@ class CERNSearch(object):
ubq_blueprint = build_ubq_blueprint(app)
app.register_blueprint(ubq_blueprint)
register_extra_routes(app)
blueprint_extra_routes = create_csas_blueprint(app)
app.register_blueprint(blueprint_extra_routes)
blueprint_record_files_content = build_blueprint_record_files_content(app)
app.register_blueprint(blueprint_record_files_content)
......
# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2015-2018 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
"""Default link factories for PID serialization into URLs.
Link factory can be specified as ``links_factory_impl`` in
:data:`invenio_records_rest.config.RECORDS_REST_ENDPOINTS` configuration.
"""
from flask import url_for
from invenio_records_rest import current_records_rest
def external_links_factory(pid, record=None, **kwargs):
"""Create factory for external record links generation.
:param pid: A Persistent Identifier instance.
:returns: Dictionary containing a list of useful links for the record.
"""
endpoint = "invenio_records_rest.{0}_item".format(current_records_rest.default_endpoint_prefixes[pid.pid_type])
links = dict(self=url_for(endpoint, pid_value=pid.pid_value, _external=True))
return links
......@@ -10,11 +10,8 @@
from invenio_records_rest.loaders.marshmallow import marshmallow_loader
from cern_search_rest_api.modules.cernsearch.loaders.json import json_marshmallow_loader
from cern_search_rest_api.modules.cernsearch.marshmallow import CSASRecordSchemaV1
csas_loader = marshmallow_loader(CSASRecordSchemaV1)
external_pid_csas_loader = json_marshmallow_loader(CSASRecordSchemaV1)
__all__ = ("csas_loader", "external_pid_csas_loader")
__all__ = ("csas_loader",)
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# This file is part of CERN Search.
# Copyright (C) 2018-2021 CERN.
#
# Citadel Search is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
"""Marshmallow loader for record deserialization.
Use marshmallow schema to transform a JSON sent via the REST API from an
external to an internal JSON presentation. The marshmallow schema further
allows for advanced data validation.
"""
from __future__ import absolute_import, print_function
from flask import request
from invenio_records_rest.loaders.marshmallow import MarshmallowErrors
from marshmallow import ValidationError
from marshmallow import __version_info__ as marshmallow_version
def json_marshmallow_loader(schema_class):
"""Marshmallow loader for JSON requests."""
def json_loader():
request_json = request.get_json()
context = {}
if marshmallow_version[0] < 3:
result = schema_class(context=context).load(request_json)
if result.errors:
raise MarshmallowErrors(result.errors)
else:
# From Marshmallow 3 the errors on .load() are being raised rather
# than returned. To adjust this change to our flow we catch these
# errors and reraise them instead.
try:
result = schema_class(context=context).load(request_json)
except ValidationError as error:
raise MarshmallowErrors(error.messages)
return result.data
return json_loader
def json_patch_loader():
"""Vanilla load for json-patch requests."""
return request.get_json(force=True)
......@@ -10,6 +10,10 @@
from invenio_pidstore.models import PIDStatus
from invenio_pidstore.providers.base import BaseProvider
from cern_search_rest_api.modules.cernsearch.errors import PIDSizeTooBigError
MAX_PID_SIZE = 36
class ExternalIdProvider(BaseProvider):
"""Vocabulary identifier provider.
......@@ -44,6 +48,9 @@ class ExternalIdProvider(BaseProvider):
"""
assert "pid_value" in kwargs
if len(kwargs["pid_value"]) > MAX_PID_SIZE:
raise PIDSizeTooBigError(cls.pid_type, kwargs["pid_value"])
kwargs.setdefault("status", cls.default_status)
if object_type and object_uuid:
kwargs["status"] = PIDStatus.REGISTERED
......
......@@ -17,6 +17,7 @@ Custom GET file api to get file content instead of real file.
from __future__ import absolute_import, print_function
import uuid
from collections import defaultdict
from copy import deepcopy
from functools import partial, wraps
from typing import Callable
......@@ -37,9 +38,13 @@ from invenio_records_rest.errors import (
UnsupportedMediaRESTError,
)
from invenio_records_rest.utils import obj_or_import_string
from invenio_records_rest.views import RecordsListResource
from invenio_records_rest.views import create_error_handlers as records_rest_error_handlers
from invenio_records_rest.views import need_record_permission, pass_record, verify_record_permission
from invenio_records_rest.views import (
RecordsListResource,
create_error_handlers,
need_record_permission,
pass_record,
verify_record_permission,
)
from invenio_rest import ContentNegotiatedMethodView
from invenio_search import current_search_client
from six import iteritems
......@@ -48,6 +53,7 @@ from sqlalchemy.exc import SQLAlchemyError
from cern_search_rest_api.modules.cernsearch.api import CernSearchRecord
from cern_search_rest_api.modules.cernsearch.errors import (
BadRequestRESTError,
ConflictError,
Error,
InvalidRecordFormatError,
......@@ -57,6 +63,11 @@ from cern_search_rest_api.modules.cernsearch.indexer import CernSearchRecordInde
from cern_search_rest_api.modules.cernsearch.search import RecordCERNSearch, csas_search_factory
def generic_bad_exception_handler(error):
"""Handle generic exceptions."""
return BadRequestRESTError(description=str(error) or error.__doc__).get_response()
def elasticsearch_query_parsing_exception_handler(error):
"""Handle query parsing exceptions from ElasticSearch."""
current_app.logger.warning(error.info)
......@@ -80,11 +91,6 @@ def elasticsearch_version_conflict_engine_exception_handler(error):
return ConflictError(errors=[Error(str(error))]).get_response()
def create_error_handlers(blueprint):
"""Create error handlers on blueprint."""
records_rest_error_handlers(blueprint)
def build_url_action_for_pid(pid, action):
"""."""
return url_for(
......@@ -138,8 +144,6 @@ def build_ubq_blueprint(app):
url_prefix="",
)
create_error_handlers(blueprint)
endpoints = app.config.get("RECORDS_REST_ENDPOINTS", [])
pid_type = "recid"
endpoint = "ubq"
......@@ -193,21 +197,30 @@ def build_ubq_blueprint(app):
methods=["PUT"],
)
return blueprint
return register_error_handlers(app, blueprint, ubq_view)
def register_error_handlers(app, blueprint, view):
"""Create error handlers on blueprint."""
error_handlers_registry = defaultdict(dict)
for endpoint, options in app.config.get("RECORDS_REST_ENDPOINTS").items():
options = deepcopy(options)
error_handlers = options.pop("error_handlers", {})
for exc_or_code, handler in error_handlers.items():
view_name = view.__name__
error_handlers_registry[exc_or_code][view_name] = handler
def register_extra_routes(app):
"""Register extra routes."""
register_extra_item_routes(app)
return create_error_handlers(blueprint, error_handlers_registry)
def register_extra_item_routes(app):
def create_csas_blueprint(app):
"""Register extra item routes."""
blueprint = Blueprint("csas", __name__, url_prefix="")
for pid_type, options in iteritems(app.config["RECORDS_REST_ENDPOINTS"]):
options = deepcopy(options)
endpoint = "record"
# Note that the '/api/' part is added transparently since this is an api_blueprint
options["item_route"] = "/{endpoint}/<pid(recid):pid_value>".format(endpoint=endpoint)
create_permission_factory = obj_or_import_string(options["create_permission_factory_imp"])
search_factory = obj_or_import_string(options["search_factory_imp"], default=csas_search_factory)
......@@ -223,14 +236,14 @@ def register_extra_item_routes(app):
mime: obj_or_import_string(func) for mime, func in options["record_serializers"].items()
}
# Specific for custom view
pid_minter = "external_recid"
# links_factory = obj_or_import_string(external_links_factory)
# record_loaders = {"application/json": obj_or_import_string(external_pid_csas_loader)}
# Add POST record/:id route
endpoint = "record"
# Note that the '/api/' part is added transparently since this is an api_blueprint
options["item_route"] = "/{endpoint}/<pid(recid):pid_value>".format(endpoint=endpoint)
item_view = ExternalIDRecordResource.as_view(
external_id_item_view = ExternalIDRecordResource.as_view(
ExternalIDRecordResource.view_name.format(pid_type),
minter_name=pid_minter,
minter_name="external_recid",
pid_type=pid_type,
pid_fetcher=options["pid_fetcher"],
create_permission_factory=create_permission_factory,
......@@ -243,12 +256,14 @@ def register_extra_item_routes(app):
search_factory=(obj_or_import_string(search_factory, default=csas_search_factory)),
)
app.add_url_rule(
blueprint.add_url_rule(
options["item_route"],
view_func=item_view,
view_func=external_id_item_view,
methods=["POST"],
)
return register_error_handlers(app, blueprint, external_id_item_view)
def pass_bucket_content_id(f: Callable):
"""Decorate to retrieve a bucket."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment