diff --git a/src/Identity/Controllers/AccountsController.cs b/src/Identity/Controllers/AccountsController.cs index e507b2299..bdf1358bf 100644 --- a/src/Identity/Controllers/AccountsController.cs +++ b/src/Identity/Controllers/AccountsController.cs @@ -8,12 +8,14 @@ using Bit.Core.Models.Data; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; +using Bit.SharedWeb.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Logging; namespace Bit.Identity.Controllers { [Route("accounts")] + [ExceptionHandlerFilter] public class AccountsController : Controller { private readonly ILogger _logger; diff --git a/src/Identity/Startup.cs b/src/Identity/Startup.cs index 37afbb88b..a38b33fa0 100644 --- a/src/Identity/Startup.cs +++ b/src/Identity/Startup.cs @@ -58,7 +58,6 @@ namespace Bit.Identity services.AddMemoryCache(); // Mvc - // MVC services.AddMvc(config => { config.Filters.Add(new ModelStateValidationFilterAttribute()); diff --git a/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs b/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs new file mode 100644 index 000000000..8723673df --- /dev/null +++ b/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs @@ -0,0 +1,93 @@ +using System; +using Bit.Core.Exceptions; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.IdentityModel.Tokens; +using Stripe; +using InternalApi = Bit.Core.Models.Api; + +namespace Bit.SharedWeb.Utilities +{ + public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute + { + public ExceptionHandlerFilterAttribute() + { + } + + public override void OnException(ExceptionContext context) + { + var errorMessage = "An error has occurred."; + + var exception = context.Exception; + if (exception == null) + { + // Should never happen. + return; + } + + InternalApi.ErrorResponseModel internalErrorModel = null; + if (exception is BadRequestException badRequestException) + { + context.HttpContext.Response.StatusCode = 400; + if (badRequestException.ModelState != null) + { + internalErrorModel = new InternalApi.ErrorResponseModel(badRequestException.ModelState); + } + else + { + errorMessage = badRequestException.Message; + } + } + else if (exception is GatewayException) + { + errorMessage = exception.Message; + context.HttpContext.Response.StatusCode = 400; + } + else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) + { + errorMessage = exception.Message; + context.HttpContext.Response.StatusCode = 400; + } + else if (exception is ApplicationException) + { + context.HttpContext.Response.StatusCode = 402; + } + else if (exception is NotFoundException) + { + errorMessage = "Resource not found."; + context.HttpContext.Response.StatusCode = 404; + } + else if (exception is SecurityTokenValidationException) + { + errorMessage = "Invalid token."; + context.HttpContext.Response.StatusCode = 403; + } + else if (exception is UnauthorizedAccessException) + { + errorMessage = "Unauthorized."; + context.HttpContext.Response.StatusCode = 401; + } + else + { + var logger = context.HttpContext.RequestServices.GetRequiredService>(); + logger.LogError(0, exception, exception.Message); + errorMessage = "An unhandled server error has occurred."; + context.HttpContext.Response.StatusCode = 500; + } + + var errorModel = internalErrorModel ?? new InternalApi.ErrorResponseModel(errorMessage); + var env = context.HttpContext.RequestServices.GetRequiredService(); + if (env.IsDevelopment()) + { + errorModel.ExceptionMessage = exception.Message; + errorModel.ExceptionStackTrace = exception.StackTrace; + errorModel.InnerExceptionMessage = exception?.InnerException?.Message; + } + context.Result = new ObjectResult(errorModel); + } + } +}