Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pkg/tcpip/link/veth/veth.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ func (e *Endpoint) SetMTU(mtu uint32) {

// Capabilities implements stack.LinkEndpoint.Capabilities.
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
// TODO(b/352384218): Enable CapabilityTXChecksumOffload.
return stack.CapabilityRXChecksumOffload | stack.CapabilitySaveRestore
return stack.CapabilityRXChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityTXChecksumOffload
}

// GSOMaxSize implements stack.GSOEndpoint.
Expand Down
4 changes: 4 additions & 0 deletions pkg/tcpip/network/ipv4/ipv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,10 @@ func (e *endpoint) forwardPacketWithRoute(route *stack.Route, pkt *stack.PacketB
newHdr.SetChecksum(0)
newHdr.SetChecksum(^newHdr.CalculateChecksum())

if route.RequiresTXTransportChecksum() {
newPkt.CalculateTransportChecksum()
}

switch err := forwardToEp.writePacketPostRouting(route, newPkt, true /* headerIncluded */); err.(type) {
case nil:
return nil
Expand Down
109 changes: 109 additions & 0 deletions pkg/tcpip/network/ipv4/ipv4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4343,3 +4343,112 @@ func TestIcmpRateLimit(t *testing.T) {
})
}
}

func newTCPPacket(t *testing.T, srcAddr, dstAddr tcpip.Address, ttl uint8, tcpChecksum uint16) *stack.PacketBuffer {
t.Helper()
ipHeaderLength := header.IPv4MinimumSize
tcpHeaderLength := header.TCPMinimumSize
totalLength := ipHeaderLength + tcpHeaderLength + 10 // 10 bytes payload
hdr := prependable.New(totalLength)

// Payload
hdr.Prepend(10)
copy(hdr.View(), []byte("1234567890"))

// TCP Header
tcpH := header.TCP(hdr.Prepend(tcpHeaderLength))
tcpH.Encode(&header.TCPFields{
SrcPort: 1234,
DstPort: 80,
SeqNum: 100,
AckNum: 200,
DataOffset: uint8(tcpHeaderLength),
Flags: header.TCPFlagSyn,
WindowSize: 65535,
})
tcpH.SetChecksum(tcpChecksum)

// IP Header
ipH := header.IPv4(hdr.Prepend(ipHeaderLength))
ipH.Encode(&header.IPv4Fields{
TotalLength: uint16(totalLength),
Protocol: uint8(header.TCPProtocolNumber),
TTL: ttl,
SrcAddr: srcAddr,
DstAddr: dstAddr,
})
ipH.SetChecksum(0)
ipH.SetChecksum(^ipH.CalculateChecksum())

pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(hdr.View()),
})
pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
return pkt
}

func TestForwardingTCPChecksum(t *testing.T) {
ctx := newTestContext()
defer ctx.cleanup()
s := ctx.s

endpoints := make(map[tcpip.NICID]*channel.Endpoint)
for nicID, addr := range defaultEndpointConfigs {
ep := channel.New(1, ipv4.MaxTotalSize, "")
defer ep.Close()

if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: addr}
if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
endpoints[nicID] = ep
}

s.SetRouteTable([]tcpip.Route{
{
Destination: incomingIPv4Addr.Subnet(),
NIC: incomingNICID,
},
{
Destination: outgoingIPv4Addr.Subnet(),
NIC: outgoingNICID,
},
})

if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil {
t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err)
}

// Inject a TCP packet with checksum 0 (invalid) into incoming NIC.
requestPkt := newTCPPacket(t, remoteIPv4Addr1, remoteIPv4Addr2, 64, 0)
defer requestPkt.DecRef()

incomingEndpoint := endpoints[incomingNICID]
incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt)

outgoingEndpoint := endpoints[outgoingNICID]
reply := outgoingEndpoint.Read()
if reply == nil {
t.Fatal("Expected forwarded TCP packet through outgoing NIC")
}
defer reply.DecRef()

// Verify that the forwarded packet has a valid TCP checksum.
payload := stack.PayloadSince(reply.NetworkHeader())
defer payload.Release()

ipv4Header := header.IPv4(payload.AsSlice())
tcpHeaderBytes := ipv4Header.Payload()
tcpHeader := header.TCP(tcpHeaderBytes)

src := ipv4Header.SourceAddress()
dst := ipv4Header.DestinationAddress()
payloadLength := uint16(len(tcpHeader.Payload()))
payloadCsum := checksum.Checksum(tcpHeader.Payload(), 0)
if !tcpHeader.IsChecksumValid(src, dst, payloadCsum, payloadLength) {
t.Errorf("expected valid TCP checksum, but got invalid")
}
}
4 changes: 4 additions & 0 deletions pkg/tcpip/network/ipv6/ipv6.go
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,10 @@ func (e *endpoint) forwardPacketWithRoute(route *stack.Route, pkt *stack.PacketB
// each node that forwards the packet.
newHdr.SetHopLimit(hopLimit - 1)

if route.RequiresTXTransportChecksum() {
newPkt.CalculateTransportChecksum()
}

forwardToEp, ok := e.protocol.getEndpointForNIC(route.NICID())
if !ok {
// The interface was removed after we obtained the route.
Expand Down
108 changes: 108 additions & 0 deletions pkg/tcpip/network/ipv6/ipv6_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4012,3 +4012,111 @@ func TestRejectMartianMappedPackets(t *testing.T) {
})
}
}

func newTCPPacket6(t *testing.T, srcAddr, dstAddr tcpip.Address, hopLimit uint8, tcpChecksum uint16) *stack.PacketBuffer {
t.Helper()
ipHeaderLength := header.IPv6MinimumSize
tcpHeaderLength := header.TCPMinimumSize
totalLength := ipHeaderLength + tcpHeaderLength + 10 // 10 bytes payload
hdr := prependable.New(totalLength)

// Payload
hdr.Prepend(10)
copy(hdr.View(), []byte("1234567890"))

// TCP Header
tcpH := header.TCP(hdr.Prepend(tcpHeaderLength))
tcpH.Encode(&header.TCPFields{
SrcPort: 1234,
DstPort: 80,
SeqNum: 100,
AckNum: 200,
DataOffset: uint8(tcpHeaderLength),
Flags: header.TCPFlagSyn,
WindowSize: 65535,
})
tcpH.SetChecksum(tcpChecksum)

// IP Header
ipH := header.IPv6(hdr.Prepend(ipHeaderLength))
ipH.Encode(&header.IPv6Fields{
PayloadLength: uint16(tcpHeaderLength + 10),
TransportProtocol: header.TCPProtocolNumber,
HopLimit: hopLimit,
SrcAddr: srcAddr,
DstAddr: dstAddr,
})

pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(hdr.View()),
})
pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
return pkt
}

func TestForwardingTCPChecksum(t *testing.T) {
ctx := newTestContext()
defer ctx.cleanup()
s := ctx.s

endpoints := make(map[tcpip.NICID]*channel.Endpoint)
for nicID, addr := range defaultEndpointConfigs {
ep := channel.New(1, header.IPv6MinimumMTU, "")
defer ep.Close()

if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
addr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: addr}
if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
s.SetNICMulticastForwarding(nicID, ProtocolNumber, true /* enabled */)
endpoints[nicID] = ep
}

s.SetRouteTable([]tcpip.Route{
{
Destination: incomingIPv6Addr.Subnet(),
NIC: incomingNICID,
},
{
Destination: outgoingIPv6Addr.Subnet(),
NIC: outgoingNICID,
},
})

if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err)
}

// Inject a TCP packet with checksum 0 (invalid) into incoming NIC.
requestPkt := newTCPPacket6(t, remoteIPv6Addr1, remoteIPv6Addr2, 64, 0)
defer requestPkt.DecRef()

incomingEndpoint := endpoints[incomingNICID]
incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt)

outgoingEndpoint := endpoints[outgoingNICID]
reply := outgoingEndpoint.Read()
if reply == nil {
t.Fatal("Expected forwarded TCP packet through outgoing NIC")
}
defer reply.DecRef()

// Verify that the forwarded packet has a valid TCP checksum.
payload := stack.PayloadSince(reply.NetworkHeader())
defer payload.Release()

ipv6Header := header.IPv6(payload.AsSlice())
tcpHeaderBytes := ipv6Header.Payload()
tcpHeader := header.TCP(tcpHeaderBytes)

src := ipv6Header.SourceAddress()
dst := ipv6Header.DestinationAddress()
payloadLength := uint16(len(tcpHeader.Payload()))
payloadCsum := checksum.Checksum(tcpHeader.Payload(), 0)
if !tcpHeader.IsChecksumValid(src, dst, payloadCsum, payloadLength) {
t.Errorf("expected valid TCP checksum, but got invalid")
}
}
77 changes: 77 additions & 0 deletions pkg/tcpip/stack/packet_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"gvisor.dev/gvisor/pkg/tcpip/header"
)

Expand Down Expand Up @@ -993,3 +994,79 @@ func UpdateHeaders(n header.Network, t header.Transport, updateSRCFields, fullCh
n.SetDestinationAddress(newAddr)
}
}

// CalculateTransportChecksum calculates the transport-layer checksum of the
// packet.
func (pk *PacketBuffer) CalculateTransportChecksum() {
netHdr, transHdr, isICMPError, ok := pk.GetHeaders()
if isICMPError {
// Skip ICMP errors because GetHeaders() returns inner headers, but pk.Data()
// contains the outer payload (including inner IP header), which would
// corrupt the checksum calculation if used as the transport payload.
// Inner headers are already incrementally updated by NAT if needed.
// This aligns with Linux, which also relies on incremental updates for
// inner headers and does not perform full recalculation from scratch.
return
}
if !ok {
// Try to parse headers from Data if not set (e.g., forwarded packet).
if pk.NetworkProtocolNumber == 0 {
return
}
netHdr = pk.Network()
transProto := netHdr.TransportProtocol()

var headerSize int
switch transProto {
case header.TCPProtocolNumber:
// Peek at minimum TCP header to find data offset (which includes options).
b, ok := pk.Data().PullUp(header.TCPMinimumSize)
if !ok {
return
}
tcp := header.TCP(b)
headerSize = int(tcp.DataOffset())
if headerSize < header.TCPMinimumSize {
return
}
case header.UDPProtocolNumber:
headerSize = header.UDPMinimumSize
default:
return
}

// Consume the transport header.
if _, ok := pk.TransportHeader().Consume(headerSize); !ok {
return
}
pk.TransportProtocolNumber = transProto

// Refresh headers.
netHdr, transHdr, isICMPError, ok = pk.GetHeaders()
if !ok || isICMPError {
return
}
}

var xsum uint16
switch t := transHdr.(type) {
case header.TCP:
src := netHdr.SourceAddress()
dst := netHdr.DestinationAddress()
proto := netHdr.TransportProtocol()
totalLen := uint16(len(t) + pk.Data().Size())
xsum = header.PseudoHeaderChecksum(proto, src, dst, totalLen)
xsum = checksum.Combine(xsum, pk.Data().Checksum())
t.SetChecksum(0)
t.SetChecksum(^t.CalculateChecksum(xsum))
case header.UDP:
src := netHdr.SourceAddress()
dst := netHdr.DestinationAddress()
proto := netHdr.TransportProtocol()
totalLen := uint16(len(t) + pk.Data().Size())
xsum = header.PseudoHeaderChecksum(proto, src, dst, totalLen)
xsum = checksum.Combine(xsum, pk.Data().Checksum())
t.SetChecksum(0)
t.SetChecksum(^t.CalculateChecksum(xsum))
}
}
Loading