diff --git a/platform/linux/cbits/hs_socket.c b/platform/linux/cbits/hs_socket.c index 04c1660..fe57a66 100644 --- a/platform/linux/cbits/hs_socket.c +++ b/platform/linux/cbits/hs_socket.c @@ -111,3 +111,9 @@ int hs_getsockname(int fd, struct sockaddr *addr, socklen_t *addrlen, int *err) *err = errno; return i; } + +int hs_shutdown(int fd, int how, int *err) { + int i = shutdown(fd, how); + *err = errno; + return i; +} diff --git a/platform/linux/include/hs_socket.h b/platform/linux/include/hs_socket.h index d1e0a3b..2fc93be 100644 --- a/platform/linux/include/hs_socket.h +++ b/platform/linux/include/hs_socket.h @@ -28,6 +28,8 @@ int hs_recvfrom(int fd, void *buf, size_t len, int flags, struct sockaddr int hs_getsockopt(int fd, int level, int option_name, void *option_value, int *option_len, int *err); int hs_setsockopt(int fd, int level, int option_name, const void *option_value, int option_len, int *err); +int hs_shutdown(int fd, int how, int *err); + #define SEOK 0 #define SEINTR EINTR #define SEAGAIN EAGAIN diff --git a/platform/linux/src/System/Socket/Internal/Platform.hsc b/platform/linux/src/System/Socket/Internal/Platform.hsc index 38ddd2a..58bac7d 100644 --- a/platform/linux/src/System/Socket/Internal/Platform.hsc +++ b/platform/linux/src/System/Socket/Internal/Platform.hsc @@ -11,7 +11,7 @@ module System.Socket.Internal.Platform ( waitRead, waitWrite, waitConnected, c_socket, c_close, c_connect, c_accept, c_bind, c_listen, c_recv, c_recvfrom, c_send, c_sendto, c_freeaddrinfo, c_getaddrinfo, c_getnameinfo, c_memset, c_gai_strerror, - c_setsockopt, c_getsockopt, c_getsockname) where + c_setsockopt, c_getsockopt, c_getsockname, c_shutdown) where import Control.Monad ( when, unless ) import Control.Concurrent.MVar @@ -114,3 +114,6 @@ foreign import ccall unsafe "gai_strerror" foreign import ccall unsafe "hs_getsockname" c_getsockname :: Fd -> Ptr a -> Ptr CInt -> Ptr CInt -> IO CInt + +foreign import ccall unsafe "hs_shutdown" + c_shutdown :: Fd -> CInt -> Ptr CInt -> IO CInt diff --git a/platform/win32/cbits/hs_socket.c b/platform/win32/cbits/hs_socket.c index 4568dfb..f68225a 100644 --- a/platform/win32/cbits/hs_socket.c +++ b/platform/win32/cbits/hs_socket.c @@ -197,3 +197,11 @@ int hs_getsockname(int fd, struct sockaddr *addr, socklen_t *addrlen, int *err) } return i; } + +int hs_shutdown(int fd, int how, int *err) { + int i = shutdown(fd, how); + if (i != 0) { + *err = WSAGetLastError(); + } + return i; +} diff --git a/platform/win32/include/hs_socket.h b/platform/win32/include/hs_socket.h index 9a3fc96..b0684c2 100644 --- a/platform/win32/include/hs_socket.h +++ b/platform/win32/include/hs_socket.h @@ -103,6 +103,8 @@ int hs_getnameinfo(const struct sockaddr *sa, int salen, void hs_freeaddrinfo(struct addrinfo *res); +int hs_shutdown(int fd, int how, int *err); + #define SEOK 0 #define SEINTR WSAEINTR #define SEAGAIN WSATRY_AGAIN diff --git a/platform/win32/src/System/Socket/Internal/Platform.hsc b/platform/win32/src/System/Socket/Internal/Platform.hsc index 3ade834..141b400 100644 --- a/platform/win32/src/System/Socket/Internal/Platform.hsc +++ b/platform/win32/src/System/Socket/Internal/Platform.hsc @@ -135,3 +135,6 @@ foreign import ccall safe "hs_getnameinfo" foreign import ccall unsafe "hs_getsockname" c_getsockname :: Fd -> Ptr a -> Ptr CInt -> Ptr CInt -> IO CInt + +foreign import ccall unsafe "hs_shutdown" + c_shutdown :: Fd -> CInt -> Ptr CInt -> IO CInt diff --git a/src/System/Socket/Type/Stream.hsc b/src/System/Socket/Type/Stream.hsc index aec2afb..b511858 100644 --- a/src/System/Socket/Type/Stream.hsc +++ b/src/System/Socket/Type/Stream.hsc @@ -20,14 +20,19 @@ module System.Socket.Type.Stream ( -- ** Specialized receive operations -- *** receiveAll , receiveAll + -- ** Stream socket shutdown + , SocketShutdown (..) + , shutdown ) where +import Control.Concurrent.MVar (withMVar) import Control.Exception (throwIO) import Control.Monad (when) import Data.Int import Data.Word import Data.Monoid import Foreign.Ptr +import Foreign.Storable (peek) import Foreign.Marshal.Alloc import qualified Data.ByteString as BS import qualified Data.ByteString.Unsafe as BS @@ -37,6 +42,7 @@ import qualified Data.ByteString.Lazy as LBS import System.Socket import System.Socket.Unsafe +import System.Socket.Internal.Platform #include "hs_socket.h" @@ -151,4 +157,23 @@ receiveAll sock maxLen flags = collect 0 Data.Monoid.mempty build accum = do return (BB.toLazyByteString accum) -{-# DEPRECATED receiveAll "Semantics will change in the next major release. Don't use it anymore!" #-} \ No newline at end of file +{-# DEPRECATED receiveAll "Semantics will change in the next major release. Don't use it anymore!" #-} + +-- | Type of socket shutdown to perform. +data SocketShutdown + -- | Disallows further reads. All future 'receive' calls will be empty + = ShutdownRead + -- | Disallows further writes. All future 'send' calls will fail + -- with an exception (platform specific: either 'ePipe' or 'eShutdown'). + | ShutdownWrite + -- | Disallow both reading and writing. + | ShutdownReadWrite + deriving (Show, Eq, Ord, Enum) + +-- | Shuts down (part of) a stream connection. +-- See 'SocketShutdown' on the effects of a socket shutdown. +shutdown :: Socket f Stream p -> SocketShutdown -> IO () +shutdown (Socket mfd) how = withMVar mfd $ \fd -> + alloca $ \errPtr -> do + i <- c_shutdown fd (fromIntegral $ fromEnum how) errPtr + when (i /= 0) $ SocketException <$> peek errPtr >>= throwIO diff --git a/test/test.hs b/test/test.hs index 90f10d8..2820839 100644 --- a/test/test.hs +++ b/test/test.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE OverloadedStrings, ScopedTypeVariables #-} module Main where import Control.Concurrent (threadDelay) @@ -33,6 +33,7 @@ main = defaultMain $ testGroup "socket" , group02 , group03 , group07 + , group08 , group80 , group98 , group99 @@ -49,6 +50,17 @@ port = 39000 port6 :: Inet6Port port6 = 39000 +-- exception types are not really portable, e. g. a write on a socket +-- that has been shutdown throws ePipe on POSIX and eShutdown on Windows, +-- so checking for exception equality is not really reasonable until we +-- have an exception compatibility layer +assertThrows :: IO a -> Assertion +assertThrows action = catch action' $ \(e :: SocketException) -> pure () + where action' = do + action + assertFailure $ mconcat + [ "No exception was thrown, a SocketException was expected." ] + group00 :: TestTree group00 = testGroup "accept" [ testGroup "Inet/Stream/TCP" @@ -381,6 +393,50 @@ group07 = testGroup "sendAll/receiveAll" ] ] +group08 :: TestTree +group08 = testGroup "shutdown" + [ shutdownCase "ShutdownRead TCP" ShutdownRead + , shutdownCase "ShutdownWrite TCP" ShutdownWrite + , shutdownCase "ShutdownReadWrite TCP" ShutdownReadWrite + ] + where shutdownIncl :: SocketShutdown -> SocketShutdown -> Bool + shutdownIncl _ ShutdownReadWrite = True + shutdownIncl ShutdownReadWrite _ = True + shutdownIncl a b = a == b + + shutdownCase :: String -> SocketShutdown -> TestTree + shutdownCase name how = testCase name $ bracket + ( do + server <- socket :: IO (Socket Inet Stream TCP) + client <- socket :: IO (Socket Inet Stream TCP) + return (server, client) + ) + ( \(server, client) -> do + close server + close client + ) + ( \(server, client) -> do + let addr = SocketAddressInet inetLoopback port + setSocketOption server (ReuseAddress True) + bind server addr + listen server 5 + serverThread <- async $ do + (peerSock, _) <- accept server + threadDelay 100000 -- give time for shutdown + send peerSock "Hello, Client!" msgNoSignal + void $ receive peerSock 1024 msgNoSignal + threadDelay 100000 + connect client addr + shutdown client how + when (shutdownIncl how ShutdownWrite) + $ assertThrows $ send client "Hello, Server!" msgNoSignal + when (shutdownIncl how ShutdownRead) $ do + -- ShutdownRead results in empty reads, not in an exception + s <- receive client 1024 msgNoSignal + assertEqual "receive empty after shutdown" mempty s + cancel serverThread + ) + group80 :: TestTree group80 = testGroup "setSocketOption" [ testGroup "V6Only" [ testCase "present" $ bracket