From e9863aef9b462ad3fa78a9533ad50f73a389bd0b Mon Sep 17 00:00:00 2001 From: sternenseemann <0rpkxez4ksa01gb3typccl0i@systemli.org> Date: Sun, 14 Mar 2021 23:37:08 +0100 Subject: [PATCH] Add shutdown operation for Stream sockets shutdown is useful since it can be used to send a TCP FIN which in turn requires the other side to close the connection as well making it possible to gracefully terminate a TCP connection. Windows is untested as of now, but CI should take care of that. Test cases have been added, unfortunately, the exceptions don't seem to line up across platforms, but that doesn't seem to be a design goal of this library. Currently the action is tied to Stream, rather than TCP (which doesn't make a difference currently), as the action is not necessarily TCP specific, but I'm not 100% certain. --- platform/linux/cbits/hs_socket.c | 6 ++ platform/linux/include/hs_socket.h | 2 + .../src/System/Socket/Internal/Platform.hsc | 5 +- platform/win32/cbits/hs_socket.c | 8 +++ platform/win32/include/hs_socket.h | 2 + .../src/System/Socket/Internal/Platform.hsc | 3 + src/System/Socket/Type/Stream.hsc | 27 ++++++++- test/test.hs | 58 ++++++++++++++++++- 8 files changed, 108 insertions(+), 3 deletions(-) diff --git a/platform/linux/cbits/hs_socket.c b/platform/linux/cbits/hs_socket.c index 04c1660..fe57a66 100644 --- a/platform/linux/cbits/hs_socket.c +++ b/platform/linux/cbits/hs_socket.c @@ -111,3 +111,9 @@ int hs_getsockname(int fd, struct sockaddr *addr, socklen_t *addrlen, int *err) *err = errno; return i; } + +int hs_shutdown(int fd, int how, int *err) { + int i = shutdown(fd, how); + *err = errno; + return i; +} diff --git a/platform/linux/include/hs_socket.h b/platform/linux/include/hs_socket.h index d1e0a3b..2fc93be 100644 --- a/platform/linux/include/hs_socket.h +++ b/platform/linux/include/hs_socket.h @@ -28,6 +28,8 @@ int hs_recvfrom(int fd, void *buf, size_t len, int flags, struct sockaddr int hs_getsockopt(int fd, int level, int option_name, void *option_value, int *option_len, int *err); int hs_setsockopt(int fd, int level, int option_name, const void *option_value, int option_len, int *err); +int hs_shutdown(int fd, int how, int *err); + #define SEOK 0 #define SEINTR EINTR #define SEAGAIN EAGAIN diff --git a/platform/linux/src/System/Socket/Internal/Platform.hsc b/platform/linux/src/System/Socket/Internal/Platform.hsc index 38ddd2a..58bac7d 100644 --- a/platform/linux/src/System/Socket/Internal/Platform.hsc +++ b/platform/linux/src/System/Socket/Internal/Platform.hsc @@ -11,7 +11,7 @@ module System.Socket.Internal.Platform ( waitRead, waitWrite, waitConnected, c_socket, c_close, c_connect, c_accept, c_bind, c_listen, c_recv, c_recvfrom, c_send, c_sendto, c_freeaddrinfo, c_getaddrinfo, c_getnameinfo, c_memset, c_gai_strerror, - c_setsockopt, c_getsockopt, c_getsockname) where + c_setsockopt, c_getsockopt, c_getsockname, c_shutdown) where import Control.Monad ( when, unless ) import Control.Concurrent.MVar @@ -114,3 +114,6 @@ foreign import ccall unsafe "gai_strerror" foreign import ccall unsafe "hs_getsockname" c_getsockname :: Fd -> Ptr a -> Ptr CInt -> Ptr CInt -> IO CInt + +foreign import ccall unsafe "hs_shutdown" + c_shutdown :: Fd -> CInt -> Ptr CInt -> IO CInt diff --git a/platform/win32/cbits/hs_socket.c b/platform/win32/cbits/hs_socket.c index 4568dfb..f68225a 100644 --- a/platform/win32/cbits/hs_socket.c +++ b/platform/win32/cbits/hs_socket.c @@ -197,3 +197,11 @@ int hs_getsockname(int fd, struct sockaddr *addr, socklen_t *addrlen, int *err) } return i; } + +int hs_shutdown(int fd, int how, int *err) { + int i = shutdown(fd, how); + if (i != 0) { + *err = WSAGetLastError(); + } + return i; +} diff --git a/platform/win32/include/hs_socket.h b/platform/win32/include/hs_socket.h index 9a3fc96..b0684c2 100644 --- a/platform/win32/include/hs_socket.h +++ b/platform/win32/include/hs_socket.h @@ -103,6 +103,8 @@ int hs_getnameinfo(const struct sockaddr *sa, int salen, void hs_freeaddrinfo(struct addrinfo *res); +int hs_shutdown(int fd, int how, int *err); + #define SEOK 0 #define SEINTR WSAEINTR #define SEAGAIN WSATRY_AGAIN diff --git a/platform/win32/src/System/Socket/Internal/Platform.hsc b/platform/win32/src/System/Socket/Internal/Platform.hsc index 3ade834..141b400 100644 --- a/platform/win32/src/System/Socket/Internal/Platform.hsc +++ b/platform/win32/src/System/Socket/Internal/Platform.hsc @@ -135,3 +135,6 @@ foreign import ccall safe "hs_getnameinfo" foreign import ccall unsafe "hs_getsockname" c_getsockname :: Fd -> Ptr a -> Ptr CInt -> Ptr CInt -> IO CInt + +foreign import ccall unsafe "hs_shutdown" + c_shutdown :: Fd -> CInt -> Ptr CInt -> IO CInt diff --git a/src/System/Socket/Type/Stream.hsc b/src/System/Socket/Type/Stream.hsc index aec2afb..b511858 100644 --- a/src/System/Socket/Type/Stream.hsc +++ b/src/System/Socket/Type/Stream.hsc @@ -20,14 +20,19 @@ module System.Socket.Type.Stream ( -- ** Specialized receive operations -- *** receiveAll , receiveAll + -- ** Stream socket shutdown + , SocketShutdown (..) + , shutdown ) where +import Control.Concurrent.MVar (withMVar) import Control.Exception (throwIO) import Control.Monad (when) import Data.Int import Data.Word import Data.Monoid import Foreign.Ptr +import Foreign.Storable (peek) import Foreign.Marshal.Alloc import qualified Data.ByteString as BS import qualified Data.ByteString.Unsafe as BS @@ -37,6 +42,7 @@ import qualified Data.ByteString.Lazy as LBS import System.Socket import System.Socket.Unsafe +import System.Socket.Internal.Platform #include "hs_socket.h" @@ -151,4 +157,23 @@ receiveAll sock maxLen flags = collect 0 Data.Monoid.mempty build accum = do return (BB.toLazyByteString accum) -{-# DEPRECATED receiveAll "Semantics will change in the next major release. Don't use it anymore!" #-} \ No newline at end of file +{-# DEPRECATED receiveAll "Semantics will change in the next major release. Don't use it anymore!" #-} + +-- | Type of socket shutdown to perform. +data SocketShutdown + -- | Disallows further reads. All future 'receive' calls will be empty + = ShutdownRead + -- | Disallows further writes. All future 'send' calls will fail + -- with an exception (platform specific: either 'ePipe' or 'eShutdown'). + | ShutdownWrite + -- | Disallow both reading and writing. + | ShutdownReadWrite + deriving (Show, Eq, Ord, Enum) + +-- | Shuts down (part of) a stream connection. +-- See 'SocketShutdown' on the effects of a socket shutdown. +shutdown :: Socket f Stream p -> SocketShutdown -> IO () +shutdown (Socket mfd) how = withMVar mfd $ \fd -> + alloca $ \errPtr -> do + i <- c_shutdown fd (fromIntegral $ fromEnum how) errPtr + when (i /= 0) $ SocketException <$> peek errPtr >>= throwIO diff --git a/test/test.hs b/test/test.hs index 90f10d8..2820839 100644 --- a/test/test.hs +++ b/test/test.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE OverloadedStrings, ScopedTypeVariables #-} module Main where import Control.Concurrent (threadDelay) @@ -33,6 +33,7 @@ main = defaultMain $ testGroup "socket" , group02 , group03 , group07 + , group08 , group80 , group98 , group99 @@ -49,6 +50,17 @@ port = 39000 port6 :: Inet6Port port6 = 39000 +-- exception types are not really portable, e. g. a write on a socket +-- that has been shutdown throws ePipe on POSIX and eShutdown on Windows, +-- so checking for exception equality is not really reasonable until we +-- have an exception compatibility layer +assertThrows :: IO a -> Assertion +assertThrows action = catch action' $ \(e :: SocketException) -> pure () + where action' = do + action + assertFailure $ mconcat + [ "No exception was thrown, a SocketException was expected." ] + group00 :: TestTree group00 = testGroup "accept" [ testGroup "Inet/Stream/TCP" @@ -381,6 +393,50 @@ group07 = testGroup "sendAll/receiveAll" ] ] +group08 :: TestTree +group08 = testGroup "shutdown" + [ shutdownCase "ShutdownRead TCP" ShutdownRead + , shutdownCase "ShutdownWrite TCP" ShutdownWrite + , shutdownCase "ShutdownReadWrite TCP" ShutdownReadWrite + ] + where shutdownIncl :: SocketShutdown -> SocketShutdown -> Bool + shutdownIncl _ ShutdownReadWrite = True + shutdownIncl ShutdownReadWrite _ = True + shutdownIncl a b = a == b + + shutdownCase :: String -> SocketShutdown -> TestTree + shutdownCase name how = testCase name $ bracket + ( do + server <- socket :: IO (Socket Inet Stream TCP) + client <- socket :: IO (Socket Inet Stream TCP) + return (server, client) + ) + ( \(server, client) -> do + close server + close client + ) + ( \(server, client) -> do + let addr = SocketAddressInet inetLoopback port + setSocketOption server (ReuseAddress True) + bind server addr + listen server 5 + serverThread <- async $ do + (peerSock, _) <- accept server + threadDelay 100000 -- give time for shutdown + send peerSock "Hello, Client!" msgNoSignal + void $ receive peerSock 1024 msgNoSignal + threadDelay 100000 + connect client addr + shutdown client how + when (shutdownIncl how ShutdownWrite) + $ assertThrows $ send client "Hello, Server!" msgNoSignal + when (shutdownIncl how ShutdownRead) $ do + -- ShutdownRead results in empty reads, not in an exception + s <- receive client 1024 msgNoSignal + assertEqual "receive empty after shutdown" mempty s + cancel serverThread + ) + group80 :: TestTree group80 = testGroup "setSocketOption" [ testGroup "V6Only" [ testCase "present" $ bracket