diff --git a/spec/client_spec.cr b/spec/client_spec.cr index 63ba866..44b4004 100644 --- a/spec/client_spec.cr +++ b/spec/client_spec.cr @@ -1,5 +1,24 @@ require "./spec_helper" +class RecordingProxyClient < HTTP::Proxy::Client + getter open_calls = 0 + + def initialize + super("127.0.0.1", 8080) + end + + def open(host, port, tls = nil, *, dns_timeout, connect_timeout, read_timeout, write_timeout) : IO + @open_calls += 1 + IO::Memory.new + end +end + +class HTTP::Client + def proxy_io_for_spec + io + end +end + describe HTTP::Proxy::Client do describe "#initialize" do it "with host and port" do @@ -18,6 +37,19 @@ describe HTTP::Proxy::Client do describe "HTTP::Client#proxy=" do context HTTP::Client do + it "keeps using proxy after client reconnect" do + proxy_client = RecordingProxyClient.new + client = HTTP::Client.new("httpbingo.org") + + client.proxy = proxy_client + proxy_client.open_calls.should eq(1) + + client.close + + client.proxy_io_for_spec + proxy_client.open_calls.should eq(2) + end + it "should make HTTP request" do with_proxy_server do |host, port, _username, _password, wants_close| proxy_client = HTTP::Proxy::Client.new(host, port) diff --git a/src/ext/http/client.cr b/src/ext/http/client.cr index 5265da4..8095a06 100644 --- a/src/ext/http/client.cr +++ b/src/ext/http/client.cr @@ -39,5 +39,48 @@ module HTTP request.headers["Proxy-Authorization"] = header end end + + # Keep proxy behavior across reconnects by rebuilding @io via proxy as well. + private def io + current_io = @io + return current_io if current_io + + unless @reconnect + raise "This HTTP::Client cannot be reconnected" + end + + if proxy = @proxy + @io = proxy.open( + host: @host, + port: @port, + tls: @tls, + dns_timeout: @dns_timeout, + connect_timeout: @connect_timeout, + read_timeout: @read_timeout, + write_timeout: @write_timeout + ) + else + hostname = @host.starts_with?('[') && @host.ends_with?(']') ? @host[1..-2] : @host + io = TCPSocket.new(hostname, @port, @dns_timeout, @connect_timeout) + io.read_timeout = @read_timeout if @read_timeout + io.write_timeout = @write_timeout if @write_timeout + io.sync = false + + {% if !flag?(:without_openssl) %} + if tls = @tls + tcp_socket = io + begin + io = OpenSSL::SSL::Socket::Client.new(tcp_socket, context: tls, sync_close: true, hostname: @host.rchop('.')) + rescue exc + # don't leak the TCP socket when the SSL connection failed + tcp_socket.close + raise exc + end + end + {% end %} + + @io = io + end + end end end