using System;
using System.Linq;
using System.Threading.Tasks;
using EmbedIO.Utilities;
namespace EmbedIO.Cors
{
/// <summary>
/// Cross-origin resource sharing (CORS) control Module.
/// CORS is a mechanism that allows restricted resources (e.g. fonts)
/// on a web page to be requested from another domain outside the domain from which the resource originated.
/// </summary>
public class CorsModule : WebModuleBase
{
    /// <summary>
    /// A string meaning "All" in CORS headers.
    /// </summary>
    public const string All = "*";

    // Raw constructor arguments, kept to detect the "allow everything" wildcard.
    private readonly string _origins;
    private readonly string _headers;
    private readonly string _methods;

    // Lower-cased, trimmed entries parsed from the comma-separated arguments.
    private readonly string[] _validOrigins;
    private readonly string[] _validMethods;

    /// <summary>
    /// Initializes a new instance of the <see cref="CorsModule"/> class.
    /// </summary>
    /// <param name="baseRoute">The base route.</param>
    /// <param name="origins">The valid origins. The default is <see cref="All"/> (*).</param>
    /// <param name="headers">The valid headers. The default is <see cref="All"/> (*).</param>
    /// <param name="methods">The valid methods. The default is <see cref="All"/> (*).</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="origins"/>
    /// or
    /// <paramref name="headers"/>
    /// or
    /// <paramref name="methods"/>
    /// is <see langword="null"/>.
    /// </exception>
    public CorsModule(
        string baseRoute,
        string origins = All,
        string headers = All,
        string methods = All)
        : base(baseRoute)
    {
        _origins = origins ?? throw new ArgumentNullException(nameof(origins));
        _headers = headers ?? throw new ArgumentNullException(nameof(headers));
        _methods = methods ?? throw new ArgumentNullException(nameof(methods));

        _validOrigins = ParseList(origins);
        _validMethods = ParseList(methods);
    }

    /// <inheritdoc />
    public override bool IsFinalHandler => false;

    /// <inheritdoc />
    /// <remarks>
    /// Sets <c>Access-Control-Allow-Origin</c> when the request origin is acceptable and,
    /// for <c>OPTIONS</c> requests, validates the request and marks the context as handled.
    /// Requests whose origin is not in the configured list are left untouched.
    /// </remarks>
    protected override Task OnRequestAsync(IHttpContext context)
    {
        var isOptions = context.Request.HttpVerb == HttpVerbs.Options;

        // If we allow all we don't need to filter
        if (_origins == All && _headers == All && _methods == All)
        {
            AllowOrigin(context, All, isOptions);
            return Task.CompletedTask;
        }

        var currentOrigin = context.Request.Headers[HttpHeaderNames.Origin];

        // Local requests that carry no Origin header get no CORS headers.
        if (string.IsNullOrWhiteSpace(currentOrigin) && context.Request.IsLocal)
        {
            return Task.CompletedTask;
        }

        if (_origins == All)
        {
            // Any origin is accepted; headers and/or methods are still restricted,
            // so OPTIONS requests must go through ValidateHttpOptions rather than
            // being silently skipped.
            AllowOrigin(context, All, isOptions);
        }
        else if (_validOrigins.Contains(currentOrigin, StringComparer.OrdinalIgnoreCase))
        {
            // The configured list is stored lower-cased, so compare ignoring case;
            // the caller's original spelling is echoed back.
            AllowOrigin(context, currentOrigin, isOptions);
        }

        return Task.CompletedTask;
    }

    // Splits a comma-separated value into lower-cased, trimmed, non-empty entries.
    private static string[] ParseList(string value)
        => value.ToLowerInvariant()
            .SplitByComma(StringSplitOptions.RemoveEmptyEntries)
            .Select(x => x.Trim())
            .ToArray();

    // Emits Access-Control-Allow-Origin and, for OPTIONS requests, validates the
    // request and marks it as handled so that no further module processes it.
    private void AllowOrigin(IHttpContext context, string origin, bool isOptions)
    {
        context.Response.Headers.Set(HttpHeaderNames.AccessControlAllowOrigin, origin);

        if (!isOptions)
        {
            return;
        }

        ValidateHttpOptions(context);
        context.SetHandled();
    }

    // Fills in Access-Control-Allow-Headers / Access-Control-Allow-Methods for an
    // OPTIONS request. Throws HttpException.BadRequest() when methods are restricted
    // and none of the requested methods is in the configured list.
    private void ValidateHttpOptions(IHttpContext context)
    {
        var requestHeadersHeader = context.Request.Headers[HttpHeaderNames.AccessControlRequestHeaders];
        if (!string.IsNullOrWhiteSpace(requestHeadersHeader))
        {
            // TODO: Remove unwanted headers from request
            // NOTE: the requested headers are echoed back as-is; the configured
            // header list is not applied here yet.
            context.Response.Headers.Set(HttpHeaderNames.AccessControlAllowHeaders, requestHeadersHeader);
        }

        var requestMethodHeader = context.Request.Headers[HttpHeaderNames.AccessControlRequestMethod];
        if (string.IsNullOrWhiteSpace(requestMethodHeader))
        {
            return;
        }

        if (_methods != All)
        {
            var requestedMethods = ParseList(requestMethodHeader);
            if (!requestedMethods.Any(method => _validMethods.Contains(method)))
            {
                throw HttpException.BadRequest();
            }
        }

        context.Response.Headers.Set(HttpHeaderNames.AccessControlAllowMethods, requestMethodHeader);
    }
}
}