using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using EmbedIO.Utilities;
using EmbedIO.WebSockets.Internal;
using Swan;
using Swan.Logging;
using Swan.Threading;
namespace EmbedIO.WebSockets
{
/// <summary>
/// A base class for modules that handle WebSocket connections.
/// </summary>
/// <remarks>
/// <para>Each WebSocket server has a list of WebSocket subprotocols it can accept.</para>
/// <para>When a client initiates a WebSocket opening handshake:</para>
/// <list type="bullet">
/// <item><description>if the list of accepted subprotocols is empty,
/// the connection is accepted only if no <c>SecWebSocketProtocol</c>
/// header is present in the request;</description></item>
/// <item><description>if the list of accepted subprotocols is not empty,
/// the connection is accepted only if one or more <c>SecWebSocketProtocol</c>
/// headers are present in the request and one of them specifies one
/// of the subprotocols in the list. The first subprotocol specified by the client
/// that is also present in the module's list is then specified in the
/// handshake response.</description></item>
/// </list>
/// <para>If a connection is not accepted because of a subprotocol mismatch,
/// a <c>400 Bad Request</c> response is sent back to the client. The response
/// contains one or more <c>SecWebSocketProtocol</c> headers that specify
/// the list of accepted subprotocols (if any).</para>
/// </remarks>
public abstract class WebSocketModule : WebModuleBase, IDisposable
{
// Size, in bytes, of the buffer handed to each receive call on a system WebSocket;
// also passed to AcceptWebSocketAsync when a connection is accepted.
private const int ReceiveBufferSize = 2048;

// When true, OnStart creates the periodic task that purges closed connections.
private readonly bool _enableConnectionWatchdog;

// Subprotocol names this module accepts; filled by AddProtocol / AddProtocols.
private readonly List<string> _protocols = new List<string>();

// Contexts of the accepted connections, keyed by context ID.
// NOTE(review): the key type is inferred from how IWebSocketContext.Id is used here - confirm it is a string.
private readonly ConcurrentDictionary<string, IWebSocketContext> _contexts = new ConcurrentDictionary<string, IWebSocketContext>();

// Set on the first call to Dispose(bool) so later calls become no-ops.
private bool _isDisposing;

// Backing fields for the MaxMessageSize, KeepAliveInterval and Encoding properties.
private int _maxMessageSize;
private TimeSpan _keepAliveInterval;
private Encoding _encoding;

// Created by OnStart only when the connection watchdog is enabled; disposed in Dispose(bool).
private PeriodicTask? _connectionWatchdog;

// NOTE(review): not read anywhere in this class - confirm whether it is still needed.
private bool _allowNullProtocol = false;
/// <summary>
/// Initializes a new instance of the <see cref="WebSocketModule"/> class.
/// </summary>
/// <param name="urlPath">The URL path of the WebSocket endpoint to serve.</param>
/// <param name="enableConnectionWatchdog">If set to <see langword="true"/>,
/// contexts representing closed connections will automatically be purged
/// from <see cref="ActiveContexts"/> every 30 seconds.</param>
protected WebSocketModule(string urlPath, bool enableConnectionWatchdog)
    : base(urlPath)
{
    // Defaults: UTF-8 text, 30-second keep-alive, no limit on message size.
    _encoding = Encoding.UTF8;
    _keepAliveInterval = TimeSpan.FromSeconds(30);
    _maxMessageSize = 0;

    _enableConnectionWatchdog = enableConnectionWatchdog;
}
/// <inheritdoc />
public sealed override bool IsFinalHandler
{
    // This module always reports itself as a final handler.
    get { return true; }
}
/// <summary>
/// <para>Gets or sets the maximum size of a received message.
/// If a message exceeding the maximum size is received from a client,
/// the connection is closed automatically.</para>
/// <para>The default value is 0, which disables message size checking.
/// Negative values are stored as 0.</para>
/// </summary>
protected int MaxMessageSize
{
    get
    {
        return _maxMessageSize;
    }

    set
    {
        EnsureConfigurationNotLocked();

        // Clamp negative values to 0 (no size checking).
        _maxMessageSize = value < 0 ? 0 : value;
    }
}
/// <summary>
/// Gets or sets the keep-alive interval for the WebSocket connection.
/// The default is 30 seconds.
/// </summary>
/// <exception cref="ArgumentOutOfRangeException">This property is being set to a value
/// that is too small to be acceptable, i.e. a negative interval other than
/// <see cref="Timeout.InfiniteTimeSpan"/>.</exception>
protected TimeSpan KeepAliveInterval
{
    get
    {
        return _keepAliveInterval;
    }

    set
    {
        EnsureConfigurationNotLocked();

        // The infinite time span is itself negative, so it is exempted from the range check.
        var isInfinite = value == Timeout.InfiniteTimeSpan;
        if (!isInfinite && value < TimeSpan.Zero)
        {
            throw new ArgumentOutOfRangeException(nameof(value), "The specified keep-alive interval is too small.");
        }

        _keepAliveInterval = value;
    }
}
/// <summary>
/// Gets or sets the <see cref="System.Text.Encoding"/> used by the
/// <see cref="SendAsync(IWebSocketContext,string)"/> method to send a string.
/// The default is <see cref="System.Text.Encoding.UTF8"/> per the WebSocket specification.
/// </summary>
/// <exception cref="ArgumentNullException">This property is being set to <see langword="null"/>.</exception>
protected Encoding Encoding
{
    get
    {
        return _encoding;
    }

    set
    {
        EnsureConfigurationNotLocked();

        var checkedValue = Validate.NotNull(nameof(value), value);
        _encoding = checkedValue;
    }
}
/// <summary>
/// Gets a list of <see cref="IWebSocketContext"/> interfaces
/// representing the currently connected clients.
/// </summary>
protected IReadOnlyList<IWebSocketContext> ActiveContexts
{
    get
    {
        // ConcurrentDictionary.Values, although declared as ICollection,
        // will probably return a ReadOnlyCollection, which implements IReadOnlyList:
        // https://referencesource.microsoft.com/#mscorlib/system/Collections/Concurrent/ConcurrentDictionary.cs,fe55c11912af21d2
        // https://github.com/dotnet/corefx/blob/master/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs#L1990
        // https://github.com/mono/mono/blob/master/mcs/class/referencesource/mscorlib/system/collections/Concurrent/ConcurrentDictionary.cs#L1961
        // However there is no formal guarantee, so be ready to convert to a list, just in case.
        var values = _contexts.Values;
        return values is IReadOnlyList<IWebSocketContext> list
            ? list
            : values.ToList();
    }
}
/// <inheritdoc />
public void Dispose()
{
    // Run the full (managed) cleanup, then tell the GC that no finalization is needed.
    Dispose(disposing: true);
    GC.SuppressFinalize(this);
}
/// <inheritdoc />
/// <remarks>
/// <para>Negotiates the subprotocol, accepts the WebSocket connection, registers its context,
/// and then processes the connection until it is closed or the request is canceled.</para>
/// <para>Exceptions thrown while notifying the connection or while processing it are logged;
/// the context is always removed from the list of active contexts afterwards.</para>
/// </remarks>
protected sealed override async Task OnRequestAsync(IHttpContext context)
{
    // The WebSocket endpoint must match exactly, giving a RequestedPath of "/".
    // In all other cases the path is longer, so there's no need to compare strings here.
    if (context.RequestedPath.Length > 1)
        return;

    // A client may list several subprotocols in a single, comma-separated header value
    // (RFC 6455, Section 4.1), so split each value before matching. Subprotocol names
    // are tokens and can never contain a comma, hence values that arrive already
    // separated pass through unchanged.
    var requestedProtocols = context.Request.Headers.GetValues(HttpHeaderNames.SecWebSocketProtocol)
        ?.SelectMany(value => value.Split(','))
        .Select(s => s.Trim())
        .Where(s => s.Length > 0)
        .ToArray()
        ?? Array.Empty<string>();

    string acceptedProtocol;
    bool acceptConnection;
    if (_protocols.Count > 0)
    {
        // Pick the first subprotocol requested by the client that this module supports.
        acceptedProtocol = requestedProtocols.FirstOrDefault(p => _protocols.Contains(p)) ?? string.Empty;
        acceptConnection = acceptedProtocol.Length > 0;
    }
    else
    {
        // No subprotocols configured: accept only clients that request none.
        acceptedProtocol = string.Empty;
        acceptConnection = requestedProtocols.Length == 0;
    }

    if (!acceptConnection)
    {
        $"{BaseRoute} - Rejecting WebSocket connection: no subprotocol was accepted.".Debug(nameof(WebSocketModule));
        foreach (var protocol in _protocols)
            context.Response.Headers.Add(HttpHeaderNames.SecWebSocketProtocol, protocol);

        // Not throwing a HTTP exception here because a WebSocket client
        // does not care about nice, formatted messages.
        context.Response.SetEmptyResponse((int)HttpStatusCode.BadRequest);
        return;
    }

    var contextImpl = context.GetImplementation();
    $"{BaseRoute} - Accepting WebSocket connection with subprotocol \"{acceptedProtocol}\"".Debug(nameof(WebSocketModule));
    var webSocketContext = await contextImpl.AcceptWebSocketAsync(
        requestedProtocols,
        acceptedProtocol,
        ReceiveBufferSize,
        KeepAliveInterval,
        context.CancellationToken).ConfigureAwait(false);

    PurgeDisconnectedContexts();
    _ = _contexts.TryAdd(webSocketContext.Id, webSocketContext);

    $"{BaseRoute} - WebSocket connection accepted - There are now {_contexts.Count} sockets connected."
        .Debug(nameof(WebSocketModule));

    try
    {
        // Notify inside the try block: should the notification throw, the finally block
        // still removes the context and disposes of the socket instead of leaking both.
        await OnClientConnectedAsync(webSocketContext).ConfigureAwait(false);

        if (webSocketContext.WebSocket is SystemWebSocket systemWebSocket)
        {
            await ProcessSystemContext(
                webSocketContext,
                systemWebSocket.UnderlyingWebSocket,
                context.CancellationToken).ConfigureAwait(false);
        }
        else
        {
            await ProcessEmbedIOContext(webSocketContext, context.CancellationToken)
                .ConfigureAwait(false);
        }
    }
    catch (OperationCanceledException)
    {
        // Cancellation is expected when the request is canceled; nothing to report.
        // TaskCanceledException derives from OperationCanceledException, so both are covered.
    }
    catch (Exception ex)
    {
        ex.Log(nameof(WebSocketModule));
    }
    finally
    {
        // once the loop is completed or connection aborted, remove the WebSocket
        RemoveWebSocket(webSocketContext);
    }
}
/// <inheritdoc />
/// <remarks>
/// When the connection watchdog is enabled, creates the periodic task
/// that purges disconnected contexts every 30 seconds.
/// </remarks>
protected override void OnStart(CancellationToken cancellationToken)
{
    if (!_enableConnectionWatchdog)
    {
        return;
    }

    var purgeInterval = TimeSpan.FromSeconds(30);
    _connectionWatchdog = new PeriodicTask(
        purgeInterval,
        ct =>
        {
            PurgeDisconnectedContexts();
            return Task.CompletedTask;
        },
        cancellationToken);
}
/// <summary>
/// Adds a WebSocket subprotocol to the list of protocols supported by a <see cref="WebSocketModule"/>.
/// </summary>
/// <param name="protocol">The protocol name to add to the list.</param>
/// <exception cref="ArgumentNullException"><paramref name="protocol"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentException">
/// <para><paramref name="protocol"/> contains one or more invalid characters, as defined
/// in <see href="https://tools.ietf.org/html/rfc6455#section-4.3">RFC6455, Section 4.3</see>.</para>
/// <para>- or -</para>
/// <para><paramref name="protocol"/> is already in the list of supported protocols.</para>
/// </exception>
/// <exception cref="InvalidOperationException">The <see cref="WebSocketModule"/> has already been started.</exception>
/// <seealso cref="AddProtocols(IEnumerable{string})"/>
/// <seealso cref="AddProtocols(string[])"/>
protected void AddProtocol(string protocol)
{
    var token = Validate.Rfc2616Token(nameof(protocol), protocol);
    EnsureConfigurationNotLocked();

    var alreadyListed = _protocols.Contains(token);
    if (alreadyListed)
    {
        throw new ArgumentException("Duplicate WebSocket protocol name.", nameof(protocol));
    }

    _protocols.Add(token);
}
/// <summary>
/// Adds one or more WebSocket subprotocols to the list of protocols supported by a <see cref="WebSocketModule"/>.
/// </summary>
/// <param name="protocols">The protocol names to add to the list.</param>
/// <exception cref="ArgumentNullException">
/// <para><paramref name="protocols"/> is <see langword="null"/>.</para>
/// <para>- or -</para>
/// <para>One or more of the strings in <paramref name="protocols"/> is <see langword="null"/>.</para>
/// </exception>
/// <exception cref="ArgumentException">
/// <para>One or more of the strings in <paramref name="protocols"/>
/// contains one or more invalid characters, as defined
/// in <see href="https://tools.ietf.org/html/rfc6455#section-4.3">RFC6455, Section 4.3</see>.</para>
/// <para>- or -</para>
/// <para>One or more of the strings in <paramref name="protocols"/>
/// is already in the list of supported protocols.</para>
/// </exception>
/// <exception cref="InvalidOperationException">The <see cref="WebSocketModule"/> has already been started.</exception>
/// <remarks>
/// This method enumerates <paramref name="protocols"/> just once; hence, if an exception is thrown
/// because one of the specified protocols is <see langword="null"/> or contains invalid characters,
/// any preceding protocol is added to the list of supported protocols.
/// </remarks>
/// <seealso cref="AddProtocol"/>
/// <seealso cref="AddProtocols(string[])"/>
protected void AddProtocols(IEnumerable<string> protocols)
{
    protocols = Validate.NotNull(nameof(protocols), protocols);
    EnsureConfigurationNotLocked();

    // Validation is lazy: each name is checked right before it is added, so names
    // preceding an invalid or duplicate one have already been added when the exception is thrown.
    foreach (var protocol in protocols.Select(p => Validate.Rfc2616Token(nameof(protocols), p)))
    {
        if (_protocols.Contains(protocol))
            throw new ArgumentException("Duplicate WebSocket protocol name.", nameof(protocols));

        _protocols.Add(protocol);
    }
}
/// <summary>
/// Adds one or more WebSocket subprotocols to the list of protocols supported by a <see cref="WebSocketModule"/>.
/// </summary>
/// <param name="protocols">The protocol names to add to the list.</param>
/// <exception cref="ArgumentNullException">
/// <para><paramref name="protocols"/> is <see langword="null"/>.</para>
/// <para>- or -</para>
/// <para>One or more of the strings in <paramref name="protocols"/> is <see langword="null"/>.</para>
/// </exception>
/// <exception cref="ArgumentException">
/// <para>One or more of the strings in <paramref name="protocols"/>
/// contains one or more invalid characters, as defined
/// in <see href="https://tools.ietf.org/html/rfc6455#section-4.3">RFC6455, Section 4.3</see>.</para>
/// <para>- or -</para>
/// <para>One or more of the strings in <paramref name="protocols"/>
/// is already in the list of supported protocols, or appears more than once
/// in <paramref name="protocols"/>.</para>
/// </exception>
/// <exception cref="InvalidOperationException">The <see cref="WebSocketModule"/> has already been started.</exception>
/// <remarks>
/// This method performs validation checks on all specified <paramref name="protocols"/> before adding them
/// to the list of supported protocols; hence, if an exception is thrown
/// because one of the specified protocols is <see langword="null"/>, contains invalid characters,
/// or is a duplicate, none of the specified protocol names are added to the list.
/// </remarks>
/// <seealso cref="AddProtocol"/>
/// <seealso cref="AddProtocols(IEnumerable{string})"/>
protected void AddProtocols(params string[] protocols)
{
    protocols = Validate.NotNull(nameof(protocols), protocols);

    // Validate every name up front, so nothing is added if any of them is invalid.
    var validated = protocols
        .Select(p => Validate.Rfc2616Token(nameof(protocols), p))
        .ToArray();

    // Seed the set with the names already supported: HashSet.Add returning false then
    // flags both a name that is already in the list and a name repeated within the arguments.
    var known = new HashSet<string>(_protocols);
    if (validated.Any(p => !known.Add(p)))
        throw new ArgumentException("Duplicate WebSocket protocol name.", nameof(protocols));

    EnsureConfigurationNotLocked();
    _protocols.AddRange(validated);
}
/// <summary>
/// Sends a text payload.
/// </summary>
/// <param name="context">The context of the WebSocket connection to send the payload to.</param>
/// <param name="payload">The payload; <see langword="null"/> is sent as an empty string.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
/// <remarks>Exceptions thrown while encoding or sending are logged and not propagated to the caller.</remarks>
protected async Task SendAsync(IWebSocketContext context, string payload)
{
    try
    {
        var text = payload ?? string.Empty;
        var bytes = _encoding.GetBytes(text);
        await context.WebSocket
            .SendAsync(bytes, true, context.CancellationToken)
            .ConfigureAwait(false);
    }
    catch (Exception ex)
    {
        ex.Log(nameof(WebSocketModule));
    }
}
#pragma warning disable CA1822 // Member can be declared as static - It is an instance method for API consistency.
/// <summary>
/// Sends a binary payload.
/// </summary>
/// <param name="context">The context of the WebSocket connection to send the payload to.</param>
/// <param name="payload">The payload; <see langword="null"/> is sent as an empty array.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
/// <remarks>Exceptions thrown while sending are logged and not propagated to the caller.</remarks>
protected async Task SendAsync(IWebSocketContext context, byte[] payload)
{
    try
    {
        await context.WebSocket.SendAsync(payload ?? Array.Empty<byte>(), false, context.CancellationToken)
            .ConfigureAwait(false);
    }
    catch (Exception ex)
    {
        ex.Log(nameof(WebSocketModule));
    }
}
#pragma warning restore CA1822
/// <summary>
/// Broadcasts the specified payload to all connected WebSocket clients.
/// </summary>
/// <param name="payload">The payload.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
protected Task BroadcastAsync(byte[] payload)
{
    // One send per registered context; the returned task completes when all sends have completed.
    var sends = _contexts.Values.Select(c => SendAsync(c, payload));
    return Task.WhenAll(sends);
}
/// <summary>
/// Broadcasts the specified payload to selected WebSocket clients.
/// </summary>
/// <param name="payload">The payload.</param>
/// <param name="selector">A callback function that must return <see langword="true"/>
/// for each context to be included in the broadcast.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
/// <exception cref="ArgumentNullException"><paramref name="selector"/> is <see langword="null"/>.</exception>
protected Task BroadcastAsync(byte[] payload, Func<IWebSocketContext, bool> selector)
    => Task.WhenAll(_contexts.Values.Where(Validate.NotNull(nameof(selector), selector)).Select(c => SendAsync(c, payload)));
/// <summary>
/// Broadcasts the specified payload to all connected WebSocket clients.
/// </summary>
/// <param name="payload">The payload.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
protected Task BroadcastAsync(string payload)
{
    // One send per registered context; the returned task completes when all sends have completed.
    var sends = _contexts.Values.Select(c => SendAsync(c, payload));
    return Task.WhenAll(sends);
}
/// <summary>
/// Broadcasts the specified payload to selected WebSocket clients.
/// </summary>
/// <param name="payload">The payload.</param>
/// <param name="selector">A callback function that must return <see langword="true"/>
/// for each context to be included in the broadcast.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
/// <exception cref="ArgumentNullException"><paramref name="selector"/> is <see langword="null"/>.</exception>
protected Task BroadcastAsync(string payload, Func<IWebSocketContext, bool> selector)
    => Task.WhenAll(_contexts.Values.Where(Validate.NotNull(nameof(selector), selector)).Select(c => SendAsync(c, payload)));
/// <summary>
/// Closes the specified web socket, removes it and disposes it.
/// </summary>
/// <param name="context">The context of the WebSocket connection to close.
/// If <see langword="null"/>, the method does nothing.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
/// <remarks>Exceptions thrown while closing are logged; the context is removed in any case.</remarks>
protected async Task CloseAsync(IWebSocketContext context)
{
    if (context != null)
    {
        try
        {
            var closing = context.WebSocket.CloseAsync(context.CancellationToken);
            await closing.ConfigureAwait(false);
        }
        catch (Exception ex)
        {
            ex.Log(nameof(WebSocketModule));
        }
        finally
        {
            // Remove and dispose even when the close attempt failed.
            RemoveWebSocket(context);
        }
    }
}
/// <summary>
/// Called when this WebSocket server receives a full message (EndOfMessage) from a client.
/// </summary>
/// <param name="context">The context of the connection the message was received on.</param>
/// <param name="buffer">The buffer containing the bytes of the received message.</param>
/// <param name="result">The result of the receive operation.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
protected abstract Task OnMessageReceivedAsync(IWebSocketContext context, byte[] buffer, IWebSocketReceiveResult result);
/// <summary>
/// Called when this WebSocket server receives a message frame regardless if the frame represents the EndOfMessage.
/// </summary>
/// <param name="context">The context of the connection the frame was received on.</param>
/// <param name="buffer">The buffer containing the bytes of the received frame.</param>
/// <param name="result">The result of the receive operation.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
/// <remarks>The default implementation does nothing.</remarks>
protected virtual Task OnFrameReceivedAsync(IWebSocketContext context, byte[] buffer, IWebSocketReceiveResult result)
{
    return Task.CompletedTask;
}
/// <summary>
/// Called when this WebSocket server accepts a new client.
/// </summary>
/// <param name="context">The context of the newly accepted connection.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
/// <remarks>The default implementation does nothing.</remarks>
protected virtual Task OnClientConnectedAsync(IWebSocketContext context)
{
    return Task.CompletedTask;
}
/// <summary>
/// Called when the server has removed a connected client for any reason.
/// </summary>
/// <param name="context">The context of the connection that was removed.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
/// <remarks>The default implementation does nothing.</remarks>
protected virtual Task OnClientDisconnectedAsync(IWebSocketContext context)
{
    return Task.CompletedTask;
}
/// <summary>
/// Releases unmanaged and - optionally - managed resources.
/// </summary>
/// <param name="disposing"><see langword="true"/> to release both managed and unmanaged resources;
/// <see langword="false"/> to release only unmanaged resources.</param>
protected virtual void Dispose(bool disposing)
{
    // Only the first call does any work; the flag is raised regardless of 'disposing'.
    var alreadyDisposing = _isDisposing;
    _isDisposing = true;
    if (alreadyDisposing || !disposing)
    {
        return;
    }

    _connectionWatchdog?.Dispose();

    // Close every registered connection and wait for all closes to finish.
    var closeTasks = _contexts.Values.Select(CloseAsync);
    Task.WhenAll(closeTasks).Await(false);

    PurgeDisconnectedContexts();
}
/// <summary>
/// Unregisters the specified context, disposes of its WebSocket and raises the
/// disconnection notification. Does nothing if the context is not registered.
/// </summary>
/// <param name="context">The context to remove.</param>
private void RemoveWebSocket(IWebSocketContext context)
{
    var wasRegistered = _contexts.TryRemove(context.Id, out _);
    if (wasRegistered)
    {
        context.WebSocket?.Dispose();

        // The notification is forked onto its own task and never awaited here;
        // whatever it throws is caught and logged inside that task.
#pragma warning disable CS4014 // Call is not awaited - it is intentionally forked.
        _ = Task.Run(() => NotifyDisconnectedAsync());
#pragma warning restore CS4014
    }

    async Task NotifyDisconnectedAsync()
    {
        try
        {
            await OnClientDisconnectedAsync(context).ConfigureAwait(false);
        }
        catch (OperationCanceledException)
        {
            $"[{context.Id}] OnClientDisconnectedAsync was canceled.".Debug(nameof(WebSocketModule));
        }
        catch (Exception e)
        {
            e.Log(nameof(WebSocketModule), $"[{context.Id}] Exception in OnClientDisconnectedAsync.");
        }
    }
}
/// <summary>
/// Removes every registered context whose WebSocket is present but no longer open,
/// then logs how many contexts were purged.
/// </summary>
private void PurgeDisconnectedContexts()
{
    var registered = _contexts.Values;
    var countBefore = _contexts.Count;
    var removedCount = 0;

    foreach (var candidate in registered)
    {
        // Contexts without a WebSocket, or with an open one, are left alone.
        var isStale = candidate.WebSocket != null && candidate.WebSocket.State != WebSocketState.Open;
        if (isStale)
        {
            RemoveWebSocket(candidate);
            removedCount++;
        }
    }

    $"{BaseRoute} - Purged {removedCount} of {countBefore} sockets."
        .Debug(nameof(WebSocketModule));
}
/// <summary>
/// Processes a connection backed by EmbedIO's own WebSocket implementation: subscribes to its
/// message event, then waits until the socket leaves the open / closing states.
/// </summary>
/// <param name="context">The context of the connection; its WebSocket is cast to <c>Internal.WebSocket</c>.</param>
/// <param name="cancellationToken">A token used to cancel the wait.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
/// <remarks>
/// Messages are forwarded to <c>OnMessageReceivedAsync</c> only; <c>OnFrameReceivedAsync</c>
/// and <c>MaxMessageSize</c> are not used by this method.
/// </remarks>
private async Task ProcessEmbedIOContext(IWebSocketContext context, CancellationToken cancellationToken)
{
    ((Internal.WebSocket)context.WebSocket).OnMessage += async (s, e) =>
    {
        // This lambda is an event handler, so nobody awaits it: an exception escaping from it
        // could never be observed by the caller. Catch and log everything here instead.
        try
        {
            if (e.Opcode == Opcode.Close)
            {
                await context.WebSocket.CloseAsync(context.CancellationToken).ConfigureAwait(false);
            }
            else
            {
                await OnMessageReceivedAsync(
                    context,
                    e.RawData,
                    new Internal.WebSocketReceiveResult(e.RawData.Length, e.Opcode))
                    .ConfigureAwait(false);
            }
        }
        catch (OperationCanceledException)
        {
            $"[{context.Id}] Message handling was canceled.".Debug(nameof(WebSocketModule));
        }
        catch (Exception ex)
        {
            ex.Log(nameof(WebSocketModule), $"[{context.Id}] Exception while handling a received message.");
        }
    };

    // Messages are delivered through the event above; this loop only keeps the request
    // alive, polling the socket state until the connection is neither open nor closing.
    while (context.WebSocket.State == WebSocketState.Open
        || context.WebSocket.State == WebSocketState.CloseReceived
        || context.WebSocket.State == WebSocketState.CloseSent)
    {
        await Task.Delay(500, cancellationToken).ConfigureAwait(false);
    }
}
/// <summary>
/// Runs the receive loop for a connection backed by a <see cref="System.Net.WebSockets.WebSocket"/>:
/// reads frames while the socket is open, reports each frame, and reports each complete message.
/// </summary>
/// <param name="context">The context of the connection, passed to the frame and message callbacks.</param>
/// <param name="webSocket">The underlying WebSocket to receive from.</param>
/// <param name="cancellationToken">A token used to cancel receive and close operations.</param>
/// <returns>A <see cref="Task"/> representing the ongoing operation.</returns>
private async Task ProcessSystemContext(IWebSocketContext context, System.Net.WebSockets.WebSocket webSocket, CancellationToken cancellationToken)
{
    // define a receive buffer, and the segment over it that every receive call reuses
    var receiveBuffer = new byte[ReceiveBufferSize];
    var receiveSegment = new ArraySegment<byte>(receiveBuffer);

    // define a dynamic buffer that holds multi-part receptions
    var receivedMessage = new List<byte>(receiveBuffer.Length * 2);

    // poll the WebSocket connections for reception
    while (webSocket.State == WebSocketState.Open)
    {
        // retrieve the result (blocking)
        var receiveResult = new SystemWebSocketReceiveResult(
            await webSocket.ReceiveAsync(receiveSegment, cancellationToken)
                .ConfigureAwait(false));

        if (receiveResult.MessageType == (int)WebSocketMessageType.Close)
        {
            // close the connection if requested by the client
            await webSocket
                .CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken)
                .ConfigureAwait(false);
            return;
        }

        // copy the frame out of the shared buffer, which the next receive overwrites
        var frameBytes = new byte[receiveResult.Count];
        Array.Copy(receiveBuffer, frameBytes, frameBytes.Length);
        await OnFrameReceivedAsync(context, frameBytes, receiveResult).ConfigureAwait(false);

        // add the response to the multi-part response
        receivedMessage.AddRange(frameBytes);

        if (_maxMessageSize > 0 && receivedMessage.Count > _maxMessageSize)
        {
            // close the connection if message exceeds max length
            await webSocket.CloseAsync(
                WebSocketCloseStatus.MessageTooBig,
                $"Message too big. Maximum is {_maxMessageSize} bytes.",
                cancellationToken).ConfigureAwait(false);

            // exit the loop; we're done
            return;
        }

        // if we're at the end of the message, process the message
        if (!receiveResult.EndOfMessage)
            continue;

        await OnMessageReceivedAsync(context, receivedMessage.ToArray(), receiveResult)
            .ConfigureAwait(false);
        receivedMessage.Clear();
    }
}
}
}