
A C# await-able SocketAsyncEventArgs wrapper for a UDP server; socket.ReceiveAsync returns invalid RemoteEndPoint (0.0.0.0:0)

As the title says, I am writing a UDP server on raw sockets and using SocketAsyncEventArgs because I would like to write something fast.

I am aware that UdpClient exists, and that there are easier ways to write a server, but I would like to learn how to properly use SocketAsyncEventArgs and the socket.ReceiveFromAsync / socket.SendToAsync methods for their claimed 'enhanced throughput' and 'better scalability' (MSDN Docs for SocketAsyncEventArgs).

I have essentially followed the MSDN example, as I figured it was a decent starting point to learn how these methods work, and ran into a bit of a problem. The server works initially and can echo back received bytes, but will randomly 'fail' to receive bytes from the correct address. Instead of the correct localhost client address (e.g. 127.0.0.1:7007), the RemoteEndPoint will be populated with the UDP placeholder EndPoint {0.0.0.0:0} (for lack of a better term). Image showing the problem (the console is in Polish, sorry; just trust me that the SocketException message is "Desired Address is invalid in this context").

At times I have ripped chunks wholesale from the MSDN example, changing only the fields that were populated on the SocketAsyncEventArgs instance for the socket.ReceiveFromAsync call (in accordance with the MSDN docs for socket.ReceiveFromAsync), and the end result was still the same. Furthermore, this is an intermittent issue, not a constant one: as far as I have noticed, there is no particular point at which the server consistently errors out.

My thoughts so far are a state issue with UdpServer, some inconsistency on the UdpClient side, or a misuse of the TaskCompletionSource.

Edit 1:

I feel that I should address why I'm using SocketAsyncEventArgs. I completely understand that there are easier ways to send and receive data. The async/await socket extensions are a good way to go about this, and are how I did it initially. I want to benchmark async/await against the older API, SocketAsyncEventArgs, to see by how much the two approaches differ.


Code for the UdpClient, UdpServer, and shared structures is included below. I can also try to provide more code on demand, if StackOverflow will let me.

Thank you for taking the time to help me.

Test Code

private static async Task TestNetworking()
        {
            EndPoint serverEndPoint = new IPEndPoint(IPAddress.Loopback, 12345);

            await Task.Factory.StartNew(async () =>
            {
                SocketClient client = new UdpClient();
                bool bound = client.Bind(new IPEndPoint(IPAddress.Any, 7007));
                if (bound)
                {
                    Console.WriteLine($"[Client] Bound client socket!");
                }

                if (bound && client.Connect(serverEndPoint))
                {
                    Console.WriteLine($"[Client] Connected to {serverEndPoint}!");

                    byte[] message = Encoding.UTF8.GetBytes("Hello World!");

                    Stopwatch stopwatch = new Stopwatch();

                    const int packetsToSend = 1_000_000;

                    for (int i = 0; i < packetsToSend; i++)
                    {
                        try
                        {
                            stopwatch.Start();

                            int sentBytes = await client.SendAsync(serverEndPoint, message, SocketFlags.None);

                            //Console.WriteLine($"[Client] Sent {sentBytes} to {serverEndPoint}");

                            ReceiveResult result = await client.ReceiveAsync(serverEndPoint, SocketFlags.None);

                            //Console.WriteLine($"[{result.RemoteEndPoint} > Client] : {Encoding.UTF8.GetString(result.Contents)}");
                            serverEndPoint = result.RemoteEndPoint;

                            stopwatch.Stop();
                        }
                        catch (Exception ex)
                        {
                            Console.WriteLine(ex);
                            i--;
                            await Task.Delay(1);
                        }
                    }

                    double approxBandwidth = (packetsToSend * message.Length) / (1_000_000.0 * (stopwatch.ElapsedMilliseconds / 1000.0));

                    Console.WriteLine($"Sent {packetsToSend} packets of {message.Length} bytes in {stopwatch.ElapsedMilliseconds:N} milliseconds.");
                    Console.WriteLine($"Approximate bandwidth: {approxBandwidth} MBps");
                }
            }, TaskCreationOptions.LongRunning);

            await Task.Factory.StartNew(async () =>
            {
                try
                {
                    SocketServer server = new UdpServer();
                    bool bound = server.Bind(serverEndPoint);
                    if (bound)
                    {
                        //Console.WriteLine($"[Server] Bound server socket!");

                        //Console.WriteLine($"Starting server at {serverEndPoint}!");

                        await server.StartAsync();
                    }
                }
                catch (Exception ex)
                {
                    Console.WriteLine(ex);
                }
            }).Result;
        }

Shared Code

public readonly struct ReceiveResult
    {
        public const int PacketSize = 1024;

        public readonly Memory<byte> Contents;

        public readonly int ReceivedBytes;

        public readonly EndPoint RemoteEndPoint;

        public ReceiveResult(Memory<byte> contents, int receivedBytes, EndPoint remoteEndPoint)
        {
            Contents = contents;

            ReceivedBytes = receivedBytes;

            RemoteEndPoint = remoteEndPoint;
        }
    }

UDP Client

public class UdpClient : SocketClient
    {
        /*
        public abstract class SocketClient
        {
            protected readonly Socket socket;

            protected SocketClient(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
            {
                socket = new Socket(addressFamily, socketType, protocolType);
            }

            public bool Bind(in EndPoint localEndPoint)
            {
                try
                {
                    socket.Bind(localEndPoint);

                    return true;
                }
                catch (Exception ex)
                {
                    Console.WriteLine(ex);

                    return false;
                }
            }

            public bool Connect(in EndPoint remoteEndPoint)
            {
                try
                {
                    socket.Connect(remoteEndPoint);

                    return true;
                }
                catch (Exception ex)
                {
                    Console.WriteLine(ex);

                    return false;
                }
            }

            public abstract Task<ReceiveResult> ReceiveAsync(EndPoint remoteEndPoint, SocketFlags socketFlags);

            public abstract Task<int> SendAsync(EndPoint remoteEndPoint, ArraySegment<byte> buffer, SocketFlags socketFlags);
        }
        */
        /// <inheritdoc />
        public UdpClient() : base(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)
        {
        }

        public override async Task<ReceiveResult> ReceiveAsync(EndPoint remoteEndPoint, SocketFlags socketFlags)
        {
            byte[] buffer = new byte[ReceiveResult.PacketSize];

            SocketReceiveFromResult result =
                await socket.ReceiveFromAsync(new ArraySegment<byte>(buffer), socketFlags, remoteEndPoint);

            return new ReceiveResult(buffer, result.ReceivedBytes, result.RemoteEndPoint);

            /*
            SocketAsyncEventArgs args = new SocketAsyncEventArgs();
            args.SetBuffer(new byte[ReceiveResult.PacketSize]);
            args.SocketFlags = socketFlags;
            args.RemoteEndPoint = remoteEndPoint;
            SocketTask awaitable = new SocketTask(args);

            while (ReceiveResult.PacketSize > args.BytesTransferred)
            {
                await socket.ReceiveFromAsync(awaitable);
            }

            return new ReceiveResult(args.MemoryBuffer, args.RemoteEndPoint);
            */
        }

        public override async Task<int> SendAsync(EndPoint remoteEndPoint, ArraySegment<byte> buffer, SocketFlags socketFlags)
        {
            return await socket.SendToAsync(buffer.ToArray(), socketFlags, remoteEndPoint);

            /*
            SocketAsyncEventArgs args = new SocketAsyncEventArgs();
            args.SetBuffer(buffer);
            args.SocketFlags = socketFlags;
            args.RemoteEndPoint = remoteEndPoint;
            SocketTask awaitable = new SocketTask(args);

            while (buffer.Length > args.BytesTransferred)
            {
                await socket.SendToAsync(awaitable);
            }

            return args.BytesTransferred;
            */
        }
    }

UDP Server

public class UdpServer : SocketServer
    {
        /*
        public abstract class SocketServer
        {
            protected readonly Socket socket;

            protected SocketServer(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
            {
                socket = new Socket(addressFamily, socketType, protocolType);
            }

            public bool Bind(in EndPoint localEndPoint)
            {
                try
                {
                    socket.Bind(localEndPoint);

                    return true;
                }
                catch (Exception ex)
                {
                    Console.WriteLine(ex);

                    return false;
                }
            }

            public abstract Task StartAsync();
        }
         */

        private const int MaxPooledObjects = 100;
        private readonly ConcurrentDictionary<EndPoint, ConcurrentQueue<byte[]>> clients;

        private readonly ArrayPool<byte> receiveBufferPool = ArrayPool<byte>.Create(ReceiveResult.PacketSize, MaxPooledObjects);

        private readonly ObjectPool<SocketAsyncEventArgs> receiveSocketAsyncEventArgsPool =
            new DefaultObjectPool<SocketAsyncEventArgs>(new DefaultPooledObjectPolicy<SocketAsyncEventArgs>(), MaxPooledObjects);

        private readonly ObjectPool<SocketAsyncEventArgs> sendSocketAsyncEventArgsPool =
            new DefaultObjectPool<SocketAsyncEventArgs>(new DefaultPooledObjectPolicy<SocketAsyncEventArgs>(), MaxPooledObjects);

        private void HandleIOCompleted(object? sender, SocketAsyncEventArgs eventArgs)
        {
            eventArgs.Completed -= HandleIOCompleted;
            bool closed = false;

            /*
             Original (local) methods in ReceiveAsync and SendAsync,
             these were assigned to eventArgs.Completed instead of HandleIOCompleted
             =======================================================================

            void ReceiveCompletedHandler(object? sender, SocketAsyncEventArgs eventArgs)
            {
                AsyncReadToken asyncReadToken = (AsyncReadToken)eventArgs.UserToken;
                eventArgs.Completed -= ReceiveCompletedHandler;

                if (eventArgs.SocketError != SocketError.Success)
                {
                    asyncReadToken.CompletionSource.TrySetException(new SocketException((int)eventArgs.SocketError));
                }
                else
                {
                    eventArgs.MemoryBuffer.CopyTo(asyncReadToken.OutputBuffer);

                    asyncReadToken.CompletionSource.TrySetResult(
                        new ReceiveResult(asyncReadToken.OutputBuffer, eventArgs.BytesTransferred, eventArgs.RemoteEndPoint));
                }

                receiveBufferPool.Return(asyncReadToken.RentedBuffer);
                receiveSocketAsyncEventArgsPool.Return(eventArgs);
            }

            void SendCompletedHandler(object? sender, SocketAsyncEventArgs eventArgs)
            {
                AsyncWriteToken asyncWriteToken = (AsyncWriteToken)eventArgs.UserToken;
                eventArgs.Completed -= SendCompletedHandler;

                if (eventArgs.SocketError != SocketError.Success)
                {
                    asyncWriteToken.CompletionSource.TrySetException(new SocketException((int)eventArgs.SocketError));
                }
                else
                {
                    asyncWriteToken.CompletionSource.TrySetResult(eventArgs.BytesTransferred);
                }

                sendSocketAsyncEventArgsPool.Return(eventArgs);
            }
            */

            switch (eventArgs.LastOperation)
            {
                case SocketAsyncOperation.SendTo:
                    AsyncWriteToken asyncWriteToken = (AsyncWriteToken)eventArgs.UserToken;

                    if (eventArgs.SocketError != SocketError.Success)
                    {
                        asyncWriteToken.CompletionSource.TrySetException(new SocketException((int)eventArgs.SocketError));
                    }
                    else
                    {
                        asyncWriteToken.CompletionSource.TrySetResult(eventArgs.BytesTransferred);
                    }

                    sendSocketAsyncEventArgsPool.Return(eventArgs);

                    break;

                case SocketAsyncOperation.ReceiveFrom:
                    AsyncReadToken asyncReadToken = (AsyncReadToken)eventArgs.UserToken;

                    if (eventArgs.SocketError != SocketError.Success)
                    {
                        asyncReadToken.CompletionSource.TrySetException(new SocketException((int)eventArgs.SocketError));
                    }
                    else
                    {
                        eventArgs.MemoryBuffer.CopyTo(asyncReadToken.OutputBuffer);

                        asyncReadToken.CompletionSource.TrySetResult(
                            new ReceiveResult(asyncReadToken.OutputBuffer, eventArgs.BytesTransferred, eventArgs.RemoteEndPoint));
                    }

                    receiveBufferPool.Return(asyncReadToken.RentedBuffer);
                    receiveSocketAsyncEventArgsPool.Return(eventArgs);

                    break;

                case SocketAsyncOperation.Disconnect:
                    closed = true;
                    break;

                case SocketAsyncOperation.Accept:
                case SocketAsyncOperation.Connect:
                case SocketAsyncOperation.None:
                    break;
            }

            if (closed)
            {
                // handle the client closing the connection on tcp servers at some point
            }
        }

        private Task<ReceiveResult> ReceiveAsync(EndPoint remoteEndPoint, SocketFlags socketFlags, Memory<byte> outputBuffer)
        {
            TaskCompletionSource<ReceiveResult> tcs = new TaskCompletionSource<ReceiveResult>();

            byte[] buffer = receiveBufferPool.Rent(ReceiveResult.PacketSize);
            Memory<byte> memoryBuffer = new Memory<byte>(buffer);

            SocketAsyncEventArgs args = receiveSocketAsyncEventArgsPool.Get();
            args.SetBuffer(memoryBuffer);
            args.SocketFlags = socketFlags;
            args.RemoteEndPoint = remoteEndPoint;
            args.UserToken = new AsyncReadToken(buffer, outputBuffer, tcs);
            args.Completed += HandleIOCompleted;

            if (socket.ReceiveFromAsync(args)) return tcs.Task;

            byte[] bufferCopy = new byte[ReceiveResult.PacketSize];

            args.MemoryBuffer.CopyTo(bufferCopy);

            ReceiveResult result = new ReceiveResult(bufferCopy, args.BytesTransferred, args.RemoteEndPoint);

            receiveBufferPool.Return(buffer);
            receiveSocketAsyncEventArgsPool.Return(args);

            return Task.FromResult(result);
        }

        private Task<int> SendAsync(EndPoint remoteEndPoint, Memory<byte> buffer, SocketFlags socketFlags)
        {
            TaskCompletionSource<int> tcs = new TaskCompletionSource<int>();

            SocketAsyncEventArgs args = sendSocketAsyncEventArgsPool.Get();
            args.SetBuffer(buffer);
            args.SocketFlags = socketFlags;
            args.RemoteEndPoint = remoteEndPoint;
            args.UserToken = new AsyncWriteToken(buffer, tcs);
            args.Completed += HandleIOCompleted;

            if (socket.SendToAsync(args)) return tcs.Task;

            int result = args.BytesTransferred;
            sendSocketAsyncEventArgsPool.Return(args);

            return Task.FromResult(result);
        }

        private readonly struct AsyncReadToken
        {
            public readonly TaskCompletionSource<ReceiveResult> CompletionSource;

            public readonly Memory<byte> OutputBuffer;
            public readonly byte[] RentedBuffer;

            public AsyncReadToken(byte[] rentedBuffer, Memory<byte> outputBuffer, TaskCompletionSource<ReceiveResult> tcs)
            {
                RentedBuffer = rentedBuffer;
                OutputBuffer = outputBuffer;

                CompletionSource = tcs;
            }
        }

        private readonly struct AsyncWriteToken
        {
            public readonly TaskCompletionSource<int> CompletionSource;

            public readonly Memory<byte> OutputBuffer;

            public AsyncWriteToken(Memory<byte> outputBuffer, TaskCompletionSource<int> tcs)
            {
                OutputBuffer = outputBuffer;

                CompletionSource = tcs;
            }
        }

        public UdpServer() : base(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)
        {
            clients = new ConcurrentDictionary<EndPoint, ConcurrentQueue<byte[]>>();
        }

        /// <inheritdoc />
        public override async Task StartAsync()
        {
            EndPoint nullEndPoint = new IPEndPoint(IPAddress.Any, 0);
            byte[] receiveBuffer = new byte[ReceiveResult.PacketSize];
            Memory<byte> receiveBufferMemory = new Memory<byte>(receiveBuffer);

            while (true)
            {
                ReceiveResult result = await ReceiveAsync(nullEndPoint, SocketFlags.None, receiveBufferMemory);

                Console.WriteLine($"[{result.RemoteEndPoint} > Server] : {Encoding.UTF8.GetString(result.Contents.Span)}");

                int sentBytes = await SendAsync(result.RemoteEndPoint, result.Contents, SocketFlags.None);
                Console.WriteLine($"[Server > {result.RemoteEndPoint}] Sent {sentBytes} bytes to {result.RemoteEndPoint}");
            }
        }
    }

I managed to solve the problem!

I ended up having to merge the SocketAsyncEventArgs pools, as it turns out that you need to persist a single args object for the duration of a receive and send call. Now my SendToAsync function takes a SocketAsyncEventArgs object (rented in the ReceiveFromAsync call), which contains the RemoteEndPoint of the client the response should be sent to. The SendToAsync function is the one that cleans up the SocketAsyncEventArgs and returns it to the pool.
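To illustrate the "one args object per receive/send pair" idea in isolation, here is a minimal sketch, not the server code itself. It echoes a single datagram over loopback and reuses the same SocketAsyncEventArgs for the ReceiveFromAsync and the following SendToAsync. The names EchoOnceSketch, RunAsync, Complete and EchoOnceAsync are made up for this example, and the blocking client calls are only there to keep it short.

```csharp
using System;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading.Tasks;

public static class EchoOnceSketch
{
    // Starts one SocketAsyncEventArgs operation and exposes it as a Task.
    // Covers both outcomes: pending (Completed fires later) and synchronous (start returns false).
    private static Task<int> RunAsync(Func<SocketAsyncEventArgs, bool> start, SocketAsyncEventArgs args)
    {
        var tcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
        args.UserToken = tcs;

        // When start returns false the Completed event is NOT raised, so finish the task here.
        if (!start(args)) Complete(args);

        return tcs.Task;
    }

    // Translates the finished operation into a task result or exception.
    private static void Complete(SocketAsyncEventArgs args)
    {
        var tcs = (TaskCompletionSource<int>)args.UserToken!;

        if (args.SocketError == SocketError.Success) tcs.TrySetResult(args.BytesTransferred);
        else tcs.TrySetException(new SocketException((int)args.SocketError));
    }

    public static async Task<string> EchoOnceAsync(string text)
    {
        using var server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
        using var client = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
        server.Bind(new IPEndPoint(IPAddress.Loopback, 0)); // port 0: let the OS pick a free port
        client.Bind(new IPEndPoint(IPAddress.Loopback, 0));
        client.ReceiveTimeout = 2000;                       // do not hang forever if the reply is lost

        using var args = new SocketAsyncEventArgs();
        args.Completed += (sender, e) => Complete(e);       // subscribed exactly once
        byte[] buffer = new byte[1024];
        args.SetBuffer(buffer, 0, buffer.Length);
        args.RemoteEndPoint = new IPEndPoint(IPAddress.Any, 0);

        // Post the receive first, then let the client send one datagram.
        Task<int> receive = RunAsync(a => server.ReceiveFromAsync(a), args);
        client.SendTo(Encoding.UTF8.GetBytes(text), server.LocalEndPoint!);
        int received = await receive;

        // Same args object: RemoteEndPoint now holds the sender's address,
        // so only the buffer window has to be narrowed before echoing.
        args.SetBuffer(0, received);
        await RunAsync(a => server.SendToAsync(a), args);

        byte[] reply = new byte[1024];
        EndPoint from = new IPEndPoint(IPAddress.Any, 0);
        int count = client.ReceiveFrom(reply, ref from);
        return Encoding.UTF8.GetString(reply, 0, count);
    }
}
```

Calling `await EchoOnceSketch.EchoOnceAsync("ping")` should hand back the text that was sent. The send step never sets RemoteEndPoint itself; it relies entirely on what the receive left behind in the args object.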

Another issue with my earlier solution was that the event handler was assigned multiple times. When I merged the two socket args pools, I left the multiple event handler assignment in place, which ended up causing problems. Once that was fixed, the solution worked completely as intended, and can send 1,000,000 packets (of 1 KB each) without any problems. Really early testing (probably off by a bit) showed a bandwidth of around 5 megabytes per second (around 40 megabits per second), which is acceptable and much better than I was getting with my own overcomplicated 'fast async' version of the code.

With regard to the bandwidth, my fast async version was overcomplicated and thus not really comparable, but I believe that this SocketAsyncEventArgs version can serve as a good jumping-off point for both benchmarking and tinkering to squeeze as much performance out of a socket as possible. I would still like feedback on this though, and will probably post it to the Code Review Stack Exchange at some point, as I doubt that there are no subtle bugs left in the solution.

Whoever wants to use this code is free to. It ended up being much simpler and easier to create than anticipated, but I don't take any responsibility if you are stupid enough to use this in production without extensive testing (this is a learning project after all).
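The effect of subscribing the same handler more than once can be shown without any networking. The sketch below uses a hypothetical subclass, RaisableArgs, whose only job is to expose the protected OnCompleted method so the event can be raised by hand. The loop stands in for a pooled args object being handed out several times with `Completed += ...` executed on every rent and never undone.

```csharp
using System;
using System.Net.Sockets;

// Hypothetical helper for this demo only: lets us raise Completed without a real socket operation.
public sealed class RaisableArgs : SocketAsyncEventArgs
{
    public void Raise() => OnCompleted(this);
}

public static class DuplicateSubscriptionDemo
{
    // Returns how many times the handler runs for ONE completion
    // after it has been subscribed 'timesRented' times.
    public static int CountHandlerCalls(int timesRented)
    {
        int calls = 0;
        EventHandler<SocketAsyncEventArgs> handler = (sender, e) => calls++;

        using var args = new RaisableArgs();

        for (int i = 0; i < timesRented; i++)
        {
            args.Completed += handler; // each += adds another entry to the invocation list
        }

        args.Raise(); // a single completion...
        return calls; // ...invokes the handler once per subscription
    }
}
```

With a real handler, the second invocation would run against an args object and buffers that have already been returned to their pools. Attaching the handler once when the args object is created, as the constructor in the server code does, avoids this.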
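For reference, the bandwidth figure uses the same formula as the test code: total bytes divided by elapsed seconds, scaled to megabytes. The small helper below (BandwidthMath is a name made up for this sketch) spells out the arithmetic and the conversion to megabits.

```csharp
using System;

public static class BandwidthMath
{
    // Same formula as in the test code: bytes / (1,000,000 * seconds).
    public static double MegabytesPerSecond(long packets, int bytesPerPacket, double elapsedMilliseconds)
    {
        double totalBytes = packets * (double)bytesPerPacket;
        double seconds = elapsedMilliseconds / 1000.0;
        return totalBytes / (1_000_000.0 * seconds);
    }

    // 1 byte = 8 bits, so MB/s * 8 = Mbit/s.
    public static double ToMegabitsPerSecond(double megabytesPerSecond) => megabytesPerSecond * 8.0;
}
```

As a worked example, 1,000,000 packets of 1024 bytes in 204,800 ms comes out at about 5 MB/s, or about 40 Mbit/s. The elapsed time here is an illustrative value chosen to match the rough figure quoted above, not a measured one.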


Test Code:

private static async Task TestNetworking()
        {
            EndPoint serverEndPoint = new IPEndPoint(IPAddress.Loopback, 12345);

            await Task.Factory.StartNew(async () =>
            {
                try
                {
                    SocketServer server = new UdpServer();
                    bool bound = server.Bind(serverEndPoint);
                    if (bound)
                    {
                        Console.WriteLine($"[Server] Bound server socket!");

                        Console.WriteLine($"[Server] Starting server at {serverEndPoint}!");

                        await server.StartAsync();
                    }
                }
                catch (Exception ex)
                {
                    Console.WriteLine(ex);
                }
            }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default);

            await Task.Factory.StartNew(async () =>
            {
                SocketClient client = new UdpClient();
                bool bound = client.Bind(new IPEndPoint(IPAddress.Any, 7007));
                if (bound)
                {
                    Console.WriteLine($"[Client] Bound client socket!");
                }

                if (bound && client.Connect(serverEndPoint))
                {
                    Console.WriteLine($"[Client] Connected to {serverEndPoint}!");

                    byte[] message = Encoding.UTF8.GetBytes("Hello World!");
                    Memory<byte> messageBuffer = new Memory<byte>(message);

                    byte[] response = new byte[ReceiveResult.PacketSize];
                    Memory<byte> responseBuffer = new Memory<byte>(response);

                    Stopwatch stopwatch = new Stopwatch();

                    const int packetsToSend = 1_000_000, statusPacketThreshold = 10_000;

                    Console.WriteLine($"Started sending packets (total packet count: {packetsToSend})");

                    for (int i = 0; i < packetsToSend; i++)
                    {
                        if (i % statusPacketThreshold == 0)
                        {
                            Console.WriteLine($"Sent {i} packets out of {packetsToSend} ({((double)i / packetsToSend) * 100:F2}%)");
                        }

                        try
                        {
                            //Console.WriteLine($"[Client > {serverEndPoint}] Sending packet {i}");
                            stopwatch.Start();

                            int sentBytes = await client.SendAsync(serverEndPoint, messageBuffer, SocketFlags.None);

                            //Console.WriteLine($"[Client] Sent {sentBytes} to {serverEndPoint}");

                            ReceiveResult result = await client.ReceiveAsync(serverEndPoint, SocketFlags.None, responseBuffer);

                            //Console.WriteLine($"[{result.RemoteEndPoint} > Client] : {Encoding.UTF8.GetString(result.Contents)}");
                            serverEndPoint = result.RemoteEndPoint;

                            stopwatch.Stop();
                        }
                        catch (Exception ex)
                        {
                            Console.WriteLine(ex);
                            i--;
                            await Task.Delay(1);
                        }
                    }

                    double approxBandwidth = (packetsToSend * ReceiveResult.PacketSize) / (1_000_000.0 * (stopwatch.ElapsedMilliseconds / 1000.0));

                    Console.WriteLine($"Sent {packetsToSend} packets of {ReceiveResult.PacketSize} bytes in {stopwatch.ElapsedMilliseconds:N} milliseconds.");
                    Console.WriteLine($"Approximate bandwidth: {approxBandwidth} MBps");
                }
            }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default).Result;
        }

Shared Code:

internal readonly struct AsyncReadToken
    {
        public readonly CancellationToken CancellationToken;
        public readonly TaskCompletionSource<ReceiveResult> CompletionSource;
        public readonly byte[] RentedBuffer;
        public readonly Memory<byte> UserBuffer;

        public AsyncReadToken(byte[] rentedBuffer, Memory<byte> userBuffer, TaskCompletionSource<ReceiveResult> tcs,
            CancellationToken cancellationToken = default)
        {
            RentedBuffer = rentedBuffer;
            UserBuffer = userBuffer;

            CompletionSource = tcs;
            CancellationToken = cancellationToken;
        }
    }

internal readonly struct AsyncWriteToken
    {
        public readonly CancellationToken CancellationToken;
        public readonly TaskCompletionSource<int> CompletionSource;
        public readonly byte[] RentedBuffer;
        public readonly Memory<byte> UserBuffer;

        public AsyncWriteToken(byte[] rentedBuffer, Memory<byte> userBuffer, TaskCompletionSource<int> tcs,
            CancellationToken cancellationToken = default)
        {
            RentedBuffer = rentedBuffer;
            UserBuffer = userBuffer;

            CompletionSource = tcs;
            CancellationToken = cancellationToken;
        }
    }

public readonly struct ReceiveResult
    {
        public const int PacketSize = 1024;

        public readonly SocketAsyncEventArgs ClientArgs;

        public readonly Memory<byte> Contents;

        public readonly int Count;

        public readonly EndPoint RemoteEndPoint;

        public ReceiveResult(SocketAsyncEventArgs clientArgs, Memory<byte> contents, int count, EndPoint remoteEndPoint)
        {
            ClientArgs = clientArgs;

            Contents = contents;
            Count = count;
            RemoteEndPoint = remoteEndPoint;
        }
    }

Server Code:

public abstract class SocketServer
    {
        protected readonly Socket socket;

        protected SocketServer(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
        {
            socket = new Socket(addressFamily, socketType, protocolType);
        }

        public bool Bind(in EndPoint localEndPoint)
        {
            try
            {
                socket.Bind(localEndPoint);

                return true;
            }
            catch (Exception ex)
            {
                Console.WriteLine(ex);

                return false;
            }
        }

        public abstract Task StartAsync();
    }

public class UdpServer : SocketServer
    {
        private const int MaxPooledObjects = 1;
        private readonly ConcurrentDictionary<EndPoint, ConcurrentQueue<byte[]>> clients;

        private readonly ArrayPool<byte> receiveBufferPool =
            ArrayPool<byte>.Create(ReceiveResult.PacketSize, MaxPooledObjects);

        private readonly ArrayPool<byte> sendBufferPool =
            ArrayPool<byte>.Create(ReceiveResult.PacketSize, MaxPooledObjects);

        private readonly ObjectPool<SocketAsyncEventArgs> socketAsyncEventArgsPool =
            new DefaultObjectPool<SocketAsyncEventArgs>(new DefaultPooledObjectPolicy<SocketAsyncEventArgs>(),
                MaxPooledObjects);

        private void HandleIOCompleted(object? sender, SocketAsyncEventArgs eventArgs)
        {
            bool closed = false;

            switch (eventArgs.LastOperation)
            {
                case SocketAsyncOperation.SendTo:
                    AsyncWriteToken asyncWriteToken = (AsyncWriteToken)eventArgs.UserToken;

                    if (asyncWriteToken.CancellationToken.IsCancellationRequested)
                    {
                        asyncWriteToken.CompletionSource.TrySetCanceled();
                    }
                    else
                    {
                        if (eventArgs.SocketError != SocketError.Success)
                        {
                            asyncWriteToken.CompletionSource.TrySetException(
                                new SocketException((int)eventArgs.SocketError));
                        }
                        else
                        {
                            asyncWriteToken.CompletionSource.TrySetResult(eventArgs.BytesTransferred);
                        }
                    }

                    sendBufferPool.Return(asyncWriteToken.RentedBuffer, true);
                    socketAsyncEventArgsPool.Return(eventArgs);

                    break;

                case SocketAsyncOperation.ReceiveFrom:
                    AsyncReadToken asyncReadToken = (AsyncReadToken)eventArgs.UserToken;

                    if (asyncReadToken.CancellationToken.IsCancellationRequested)
                    {
                        asyncReadToken.CompletionSource.SetCanceled();
                    }
                    else
                    {
                        if (eventArgs.SocketError != SocketError.Success)
                        {
                            asyncReadToken.CompletionSource.SetException(
                                new SocketException((int)eventArgs.SocketError));
                        }
                        else
                        {
                            eventArgs.MemoryBuffer.CopyTo(asyncReadToken.UserBuffer);
                            ReceiveResult result = new ReceiveResult(eventArgs, asyncReadToken.UserBuffer,
                                eventArgs.BytesTransferred, eventArgs.RemoteEndPoint);

                            asyncReadToken.CompletionSource.SetResult(result);
                        }
                    }

                    receiveBufferPool.Return(asyncReadToken.RentedBuffer, true);

                    break;

                case SocketAsyncOperation.Disconnect:
                    closed = true;
                    break;

                case SocketAsyncOperation.Accept:
                case SocketAsyncOperation.Connect:
                case SocketAsyncOperation.None:
                case SocketAsyncOperation.Receive:
                case SocketAsyncOperation.ReceiveMessageFrom:
                case SocketAsyncOperation.Send:
                case SocketAsyncOperation.SendPackets:
                    throw new NotImplementedException();

                default:
                    throw new ArgumentOutOfRangeException();
            }

            if (closed)
            {
                // handle the client closing the connection on tcp servers at some point
            }
        }

        private Task<ReceiveResult> ReceiveAsync(EndPoint remoteEndPoint, SocketFlags socketFlags,
            Memory<byte> outputBuffer, CancellationToken cancellationToken = default)
        {
            TaskCompletionSource<ReceiveResult> tcs = new TaskCompletionSource<ReceiveResult>();

            byte[] rentedBuffer = receiveBufferPool.Rent(ReceiveResult.PacketSize);
            Memory<byte> memoryBuffer = new Memory<byte>(rentedBuffer);

            SocketAsyncEventArgs args = socketAsyncEventArgsPool.Get();

            args.SetBuffer(memoryBuffer);
            args.SocketFlags = socketFlags;
            args.RemoteEndPoint = remoteEndPoint;
            args.UserToken = new AsyncReadToken(rentedBuffer, outputBuffer, tcs, cancellationToken);

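            // if the receive operation is pending, HandleIOCompleted will complete the task; return it to be awaited
            if (socket.ReceiveFromAsync(args)) return tcs.Task;

            // the operation completed synchronously, so the Completed event is not raised; check for errors here
            if (args.SocketError != SocketError.Success)
            {
                SocketException exception = new SocketException((int)args.SocketError);

                receiveBufferPool.Return(rentedBuffer, true);
                socketAsyncEventArgsPool.Return(args);

                return Task.FromException<ReceiveResult>(exception);
            }

            args.MemoryBuffer.CopyTo(outputBuffer);

            ReceiveResult result = new ReceiveResult(args, outputBuffer, args.BytesTransferred, args.RemoteEndPoint);

            receiveBufferPool.Return(rentedBuffer, true);

            return Task.FromResult(result);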
        }

        private Task<int> SendAsync(SocketAsyncEventArgs clientArgs, Memory<byte> inputBuffer, SocketFlags socketFlags,
            CancellationToken cancellationToken = default)
        {
            TaskCompletionSource<int> tcs = new TaskCompletionSource<int>();

            byte[] rentedBuffer = sendBufferPool.Rent(ReceiveResult.PacketSize);
            Memory<byte> memoryBuffer = new Memory<byte>(rentedBuffer);

            inputBuffer.CopyTo(memoryBuffer);

            SocketAsyncEventArgs args = clientArgs;
            args.SetBuffer(memoryBuffer);
            args.SocketFlags = socketFlags;
            args.UserToken = new AsyncWriteToken(rentedBuffer, inputBuffer, tcs, cancellationToken);

            // if the send operation doesn't complete synchronously, return the awaitable task
            if (socket.SendToAsync(args)) return tcs.Task;

            int result = args.BytesTransferred;

            sendBufferPool.Return(rentedBuffer, true);
            socketAsyncEventArgsPool.Return(args);

            return Task.FromResult(result);
        }

        public UdpServer() : base(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)
        {
            clients = new ConcurrentDictionary<EndPoint, ConcurrentQueue<byte[]>>();

            for (int i = 0; i < MaxPooledObjects; i++)
            {
                SocketAsyncEventArgs args = new SocketAsyncEventArgs();
                args.Completed += HandleIOCompleted;

                socketAsyncEventArgsPool.Return(args);
            }
        }

        /// <inheritdoc />
        public override async Task StartAsync()
        {
            EndPoint nullEndPoint = new IPEndPoint(IPAddress.Any, 0);
            byte[] receiveBuffer = new byte[ReceiveResult.PacketSize];
            Memory<byte> receiveBufferMemory = new Memory<byte>(receiveBuffer);

            while (true)
            {
                ReceiveResult result = await ReceiveAsync(nullEndPoint, SocketFlags.None, receiveBufferMemory);

                //Console.WriteLine($"[{result.RemoteEndPoint} > Server] : {Encoding.UTF8.GetString(result.Contents.Span)}");

                int sentBytes = await SendAsync(result.ClientArgs, result.Contents, SocketFlags.None);

                //Console.WriteLine($"[Server > {result.RemoteEndPoint}] Sent {sentBytes} bytes to {result.RemoteEndPoint}");
            }
        }

Client Code:

    public abstract class SocketClient
    {
        protected readonly Socket socket;

        protected SocketClient(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
        {
            socket = new Socket(addressFamily, socketType, protocolType);
        }

        public bool Bind(in EndPoint localEndPoint)
        {
            try
            {
                socket.Bind(localEndPoint);

                return true;
            }
            catch (Exception ex)
            {
                Console.WriteLine(ex);

                return false;
            }
        }

        public bool Connect(in EndPoint remoteEndPoint)
        {
            try
            {
                socket.Connect(remoteEndPoint);

                return true;
            }
            catch (Exception ex)
            {
                Console.WriteLine(ex);

                return false;
            }
        }

        public abstract Task<ReceiveResult> ReceiveAsync(EndPoint remoteEndPoint, SocketFlags socketFlags,
            Memory<byte> outputBuffer);

        public abstract Task<int> SendAsync(EndPoint remoteEndPoint, Memory<byte> inputBuffer, SocketFlags socketFlags);
    }

    public class UdpClient : SocketClient
    {
        /// <inheritdoc />
        public UdpClient() : base(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)
        {
        }

        public override async Task<ReceiveResult> ReceiveAsync(EndPoint remoteEndPoint, SocketFlags socketFlags,
            Memory<byte> outputBuffer)
        {
            byte[] buffer = new byte[ReceiveResult.PacketSize];

            SocketReceiveFromResult result =
                await socket.ReceiveFromAsync(new ArraySegment<byte>(buffer), socketFlags, remoteEndPoint);

            buffer.CopyTo(outputBuffer);

            return new ReceiveResult(default, outputBuffer, result.ReceivedBytes, result.RemoteEndPoint);

            /*
            SocketAsyncEventArgs args = new SocketAsyncEventArgs();
            args.SetBuffer(new byte[ReceiveResult.PacketSize]);
            args.SocketFlags = socketFlags;
            args.RemoteEndPoint = remoteEndPoint;
            SocketTask awaitable = new SocketTask(args);

            while (ReceiveResult.PacketSize > args.BytesTransferred)
            {
                await socket.ReceiveFromAsync(awaitable);
            }

            return new ReceiveResult(args.MemoryBuffer, args.RemoteEndPoint);
            */
        }

        public override async Task<int> SendAsync(EndPoint remoteEndPoint, Memory<byte> buffer, SocketFlags socketFlags)
        {
            return await socket.SendToAsync(buffer.ToArray(), socketFlags, remoteEndPoint);

            /*
            SocketAsyncEventArgs args = new SocketAsyncEventArgs();
            args.SetBuffer(buffer);
            args.SocketFlags = socketFlags;
            args.RemoteEndPoint = remoteEndPoint;
            SocketTask awaitable = new SocketTask(args);

            while (buffer.Length > args.BytesTransferred)
            {
                await socket.SendToAsync(awaitable);
            }

            return args.BytesTransferred;
            */
        }
    }
