Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions platform/linux/cbits/hs_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,9 @@ int hs_getsockname(int fd, struct sockaddr *addr, socklen_t *addrlen, int *err)
*err = errno;
return i;
}

int hs_shutdown(int fd, int how, int *err) {
int i = shutdown(fd, how);
*err = errno;
return i;
}
2 changes: 2 additions & 0 deletions platform/linux/include/hs_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ int hs_recvfrom(int fd, void *buf, size_t len, int flags, struct sockaddr
int hs_getsockopt(int fd, int level, int option_name, void *option_value, int *option_len, int *err);
int hs_setsockopt(int fd, int level, int option_name, const void *option_value, int option_len, int *err);

int hs_shutdown(int fd, int how, int *err);

#define SEOK 0
#define SEINTR EINTR
#define SEAGAIN EAGAIN
Expand Down
5 changes: 4 additions & 1 deletion platform/linux/src/System/Socket/Internal/Platform.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module System.Socket.Internal.Platform
( waitRead, waitWrite, waitConnected, c_socket, c_close, c_connect,
c_accept, c_bind, c_listen, c_recv, c_recvfrom, c_send, c_sendto,
c_freeaddrinfo, c_getaddrinfo, c_getnameinfo, c_memset, c_gai_strerror,
c_setsockopt, c_getsockopt, c_getsockname) where
c_setsockopt, c_getsockopt, c_getsockname, c_shutdown) where

import Control.Monad ( when, unless )
import Control.Concurrent.MVar
Expand Down Expand Up @@ -114,3 +114,6 @@ foreign import ccall unsafe "gai_strerror"

foreign import ccall unsafe "hs_getsockname"
c_getsockname :: Fd -> Ptr a -> Ptr CInt -> Ptr CInt -> IO CInt

foreign import ccall unsafe "hs_shutdown"
c_shutdown :: Fd -> CInt -> Ptr CInt -> IO CInt
8 changes: 8 additions & 0 deletions platform/win32/cbits/hs_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,11 @@ int hs_getsockname(int fd, struct sockaddr *addr, socklen_t *addrlen, int *err)
}
return i;
}

int hs_shutdown(int fd, int how, int *err) {
int i = shutdown(fd, how);
if (i != 0) {
*err = WSAGetLastError();
}
return i;
}
2 changes: 2 additions & 0 deletions platform/win32/include/hs_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ int hs_getnameinfo(const struct sockaddr *sa, int salen,

void hs_freeaddrinfo(struct addrinfo *res);

int hs_shutdown(int fd, int how, int *err);

#define SEOK 0
#define SEINTR WSAEINTR
#define SEAGAIN WSATRY_AGAIN
Expand Down
3 changes: 3 additions & 0 deletions platform/win32/src/System/Socket/Internal/Platform.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,6 @@ foreign import ccall safe "hs_getnameinfo"

foreign import ccall unsafe "hs_getsockname"
c_getsockname :: Fd -> Ptr a -> Ptr CInt -> Ptr CInt -> IO CInt

foreign import ccall unsafe "hs_shutdown"
c_shutdown :: Fd -> CInt -> Ptr CInt -> IO CInt
27 changes: 26 additions & 1 deletion src/System/Socket/Type/Stream.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@ module System.Socket.Type.Stream (
-- ** Specialized receive operations
-- *** receiveAll
, receiveAll
-- ** Stream socket shutdown
, SocketShutdown (..)
, shutdown
) where

import Control.Concurrent.MVar (withMVar)
import Control.Exception (throwIO)
import Control.Monad (when)
import Data.Int
import Data.Word
import Data.Monoid
import Foreign.Ptr
import Foreign.Storable (peek)
import Foreign.Marshal.Alloc
import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BS
Expand All @@ -37,6 +42,7 @@ import qualified Data.ByteString.Lazy as LBS

import System.Socket
import System.Socket.Unsafe
import System.Socket.Internal.Platform

#include "hs_socket.h"

Expand Down Expand Up @@ -151,4 +157,23 @@ receiveAll sock maxLen flags = collect 0 Data.Monoid.mempty
build accum = do
return (BB.toLazyByteString accum)

{-# DEPRECATED receiveAll "Semantics will change in the next major release. Don't use it anymore!" #-}
{-# DEPRECATED receiveAll "Semantics will change in the next major release. Don't use it anymore!" #-}

-- | Type of socket shutdown to perform.
data SocketShutdown
-- | Disallows further reads. All future 'receive' calls will be empty
= ShutdownRead
-- | Disallows further writes. All future 'send' calls will fail
-- with an exception (platform specific: either 'ePipe' or 'eShutdown').
| ShutdownWrite
-- | Disallow both reading and writing.
| ShutdownReadWrite
deriving (Show, Eq, Ord, Enum)

-- | Shuts down (part of) a stream connection.
-- See 'SocketShutdown' on the effects of a socket shutdown.
shutdown :: Socket f Stream p -> SocketShutdown -> IO ()
shutdown (Socket mfd) how = withMVar mfd $ \fd ->
alloca $ \errPtr -> do
i <- c_shutdown fd (fromIntegral $ fromEnum how) errPtr
when (i /= 0) $ SocketException <$> peek errPtr >>= throwIO
58 changes: 57 additions & 1 deletion test/test.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedStrings, ScopedTypeVariables #-}
module Main where

import Control.Concurrent (threadDelay)
Expand Down Expand Up @@ -33,6 +33,7 @@ main = defaultMain $ testGroup "socket"
, group02
, group03
, group07
, group08
, group80
, group98
, group99
Expand All @@ -49,6 +50,17 @@ port = 39000
port6 :: Inet6Port
port6 = 39000

-- exception types are not really portable, e. g. a write on a socket
-- that has been shutdown throws ePipe on POSIX and eShutdown on Windows,
-- so checking for exception equality is not really reasonable until we
-- have an exception compatibility layer
assertThrows :: IO a -> Assertion
assertThrows action = catch action' $ \(e :: SocketException) -> pure ()
where action' = do
action
assertFailure $ mconcat
[ "No exception was thrown, a SocketException was expected." ]

group00 :: TestTree
group00 = testGroup "accept"
[ testGroup "Inet/Stream/TCP"
Expand Down Expand Up @@ -381,6 +393,50 @@ group07 = testGroup "sendAll/receiveAll"
]
]

group08 :: TestTree
group08 = testGroup "shutdown"
[ shutdownCase "ShutdownRead TCP" ShutdownRead
, shutdownCase "ShutdownWrite TCP" ShutdownWrite
, shutdownCase "ShutdownReadWrite TCP" ShutdownReadWrite
]
where shutdownIncl :: SocketShutdown -> SocketShutdown -> Bool
shutdownIncl _ ShutdownReadWrite = True
shutdownIncl ShutdownReadWrite _ = True
shutdownIncl a b = a == b

shutdownCase :: String -> SocketShutdown -> TestTree
shutdownCase name how = testCase name $ bracket
( do
server <- socket :: IO (Socket Inet Stream TCP)
client <- socket :: IO (Socket Inet Stream TCP)
return (server, client)
)
( \(server, client) -> do
close server
close client
)
( \(server, client) -> do
let addr = SocketAddressInet inetLoopback port
setSocketOption server (ReuseAddress True)
bind server addr
listen server 5
serverThread <- async $ do
(peerSock, _) <- accept server
threadDelay 100000 -- give time for shutdown
send peerSock "Hello, Client!" msgNoSignal
void $ receive peerSock 1024 msgNoSignal
threadDelay 100000
connect client addr
shutdown client how
when (shutdownIncl how ShutdownWrite)
$ assertThrows $ send client "Hello, Server!" msgNoSignal
when (shutdownIncl how ShutdownRead) $ do
-- ShutdownRead results in empty reads, not in an exception
s <- receive client 1024 msgNoSignal
assertEqual "receive empty after shutdown" mempty s
cancel serverThread
)

group80 :: TestTree
group80 = testGroup "setSocketOption" [ testGroup "V6Only"
[ testCase "present" $ bracket
Expand Down