130 lines
4.8 KiB
C#
130 lines
4.8 KiB
C#
using System;
|
|
using System.Linq;
|
|
using System.Threading.Tasks;
|
|
using EmbedIO.Utilities;
|
|
|
|
namespace EmbedIO.Cors
{
    /// <summary>
    /// Cross-origin resource sharing (CORS) control module.
    /// CORS is a mechanism that allows restricted resources (e.g. fonts)
    /// on a web page to be requested from another domain outside the domain from which the resource originated.
    /// </summary>
    /// <remarks>
    /// <para>This module is not a final handler (<see cref="IsFinalHandler"/> is <see langword="false"/>):
    /// for ordinary requests it only adds CORS response headers and lets processing continue.</para>
    /// <para>A preflight (<c>OPTIONS</c>) request whose origin is accepted gets its
    /// <c>Access-Control-Allow-Headers</c> / <c>Access-Control-Allow-Methods</c> response headers set
    /// and is then marked as handled via <c>SetHandled</c>.</para>
    /// <para>Origins, headers and methods are comma-separated lists compared case-insensitively (ordinal).</para>
    /// </remarks>
    public class CorsModule : WebModuleBase
    {
        /// <summary>
        /// A string meaning "All" in CORS headers.
        /// </summary>
        public const string All = "*";

        // Raw constructor arguments; compared against All to detect "allow everything" per dimension.
        private readonly string _origins;
        private readonly string _headers;
        private readonly string _methods;

        // Parsed lists: trimmed, non-empty entries, compared with StringComparer.OrdinalIgnoreCase.
        private readonly string[] _validOrigins;
        private readonly string[] _validHeaders;
        private readonly string[] _validMethods;

        /// <summary>
        /// Initializes a new instance of the <see cref="CorsModule" /> class.
        /// </summary>
        /// <param name="baseRoute">The base route.</param>
        /// <param name="origins">The valid origins, as a comma-separated list.
        /// The default is <see cref="All"/> (<c>*</c>), which accepts any origin.</param>
        /// <param name="headers">The valid request headers, as a comma-separated list.
        /// The default is <see cref="All"/> (<c>*</c>), which echoes back every header a preflight asks for.</param>
        /// <param name="methods">The valid methods, as a comma-separated list.
        /// The default is <see cref="All"/> (<c>*</c>), which accepts any requested method.</param>
        /// <exception cref="ArgumentNullException">
        /// origins
        /// or
        /// headers
        /// or
        /// methods
        /// </exception>
        public CorsModule(
            string baseRoute,
            string origins = All,
            string headers = All,
            string methods = All)
            : base(baseRoute)
        {
            _origins = origins ?? throw new ArgumentNullException(nameof(origins));
            _headers = headers ?? throw new ArgumentNullException(nameof(headers));
            _methods = methods ?? throw new ArgumentNullException(nameof(methods));

            _validOrigins = ParseList(origins);
            _validHeaders = ParseList(headers);
            _validMethods = ParseList(methods);
        }

        /// <inheritdoc />
        public override bool IsFinalHandler => false;

        /// <inheritdoc />
        protected override Task OnRequestAsync(IHttpContext context)
        {
            var isOptions = context.Request.HttpVerb == HttpVerbs.Options;

            // If we allow all we don't need to filter.
            if (_origins == All && _headers == All && _methods == All)
            {
                AllowOrigin(context, All, isOptions);
                return Task.CompletedTask;
            }

            var currentOrigin = context.Request.Headers[HttpHeaderNames.Origin];

            // A local request without an Origin header is left untouched.
            if (string.IsNullOrWhiteSpace(currentOrigin) && context.Request.IsLocal)
                return Task.CompletedTask;

            if (_origins == All)
            {
                // Any origin is accepted, but a preflight still goes through header/method filtering.
                // (Previously this case returned without emitting any CORS header at all.)
                AllowOrigin(context, All, isOptions);
            }
            else if (ContainsIgnoreCase(_validOrigins, currentOrigin))
            {
                // Echo the origin exactly as the client sent it.
                AllowOrigin(context, currentOrigin, isOptions);
            }

            return Task.CompletedTask;
        }

        // Sets Access-Control-Allow-Origin to the given value; for a preflight, also validates
        // the requested headers/methods and then marks the request as handled.
        private void AllowOrigin(IHttpContext context, string allowedOrigin, bool isOptions)
        {
            context.Response.Headers.Set(HttpHeaderNames.AccessControlAllowOrigin, allowedOrigin);

            if (!isOptions)
                return;

            // ValidateHttpOptions may throw; SetHandled is only reached when validation succeeds.
            ValidateHttpOptions(context);
            context.SetHandled();
        }

        // Answers the Access-Control-Request-Headers / Access-Control-Request-Method parts of a preflight.
        // Throws HttpException.BadRequest() when methods are restricted and none of the requested ones is allowed.
        private void ValidateHttpOptions(IHttpContext context)
        {
            var requestHeadersHeader = context.Request.Headers[HttpHeaderNames.AccessControlRequestHeaders];
            if (!string.IsNullOrWhiteSpace(requestHeadersHeader))
            {
                // When headers are restricted, only return the requested headers that are configured;
                // the client rejects the preflight for any header missing from the response.
                var allowedHeaders = _headers == All
                    ? requestHeadersHeader
                    : string.Join(", ", ParseList(requestHeadersHeader).Where(h => ContainsIgnoreCase(_validHeaders, h)));

                if (allowedHeaders.Length > 0)
                    context.Response.Headers.Set(HttpHeaderNames.AccessControlAllowHeaders, allowedHeaders);
            }

            var requestMethodHeader = context.Request.Headers[HttpHeaderNames.AccessControlRequestMethod];
            if (string.IsNullOrWhiteSpace(requestMethodHeader))
                return;

            if (_methods != All && !ParseList(requestMethodHeader).Any(m => ContainsIgnoreCase(_validMethods, m)))
                throw HttpException.BadRequest();

            context.Response.Headers.Set(HttpHeaderNames.AccessControlAllowMethods, requestMethodHeader);
        }

        // Splits a comma-separated list, trims every entry and drops entries that are empty after trimming,
        // so input like "a, ,b" cannot produce an empty entry that would match an empty header value.
        private static string[] ParseList(string value)
            => value
                .SplitByComma(StringSplitOptions.RemoveEmptyEntries)
                .Select(x => x.Trim())
                .Where(x => x.Length > 0)
                .ToArray();

        // Ordinal, case-insensitive membership test; a null value never matches.
        private static bool ContainsIgnoreCase(string[] list, string value)
            => value != null && list.Contains(value, StringComparer.OrdinalIgnoreCase);
    }
}