From b000ea2df8c8f45760b7a794863b9a2d756fc8ab Mon Sep 17 00:00:00 2001 From: Jing Chen Date: Fri, 22 May 2026 15:54:59 -0700 Subject: [PATCH] Enable CapabilityTXChecksumOffload for veth devices. PiperOrigin-RevId: 919899576 --- pkg/tcpip/link/veth/veth.go | 3 +- pkg/tcpip/network/ipv4/ipv4.go | 4 + pkg/tcpip/network/ipv4/ipv4_test.go | 109 ++++++++++++++++++++++++++++ pkg/tcpip/network/ipv6/ipv6.go | 4 + pkg/tcpip/network/ipv6/ipv6_test.go | 108 +++++++++++++++++++++++++++ pkg/tcpip/stack/packet_buffer.go | 77 ++++++++++++++++++++ 6 files changed, 303 insertions(+), 2 deletions(-) diff --git a/pkg/tcpip/link/veth/veth.go b/pkg/tcpip/link/veth/veth.go index c22f8e9609..261cbcb41e 100644 --- a/pkg/tcpip/link/veth/veth.go +++ b/pkg/tcpip/link/veth/veth.go @@ -162,8 +162,7 @@ func (e *Endpoint) SetMTU(mtu uint32) { // Capabilities implements stack.LinkEndpoint.Capabilities. func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - // TODO(b/352384218): Enable CapabilityTXChecksumOffload. - return stack.CapabilityRXChecksumOffload | stack.CapabilitySaveRestore + return stack.CapabilityRXChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityTXChecksumOffload } // GSOMaxSize implements stack.GSOEndpoint. diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 714cf980d7..3fd83618e4 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -747,6 +747,10 @@ func (e *endpoint) forwardPacketWithRoute(route *stack.Route, pkt *stack.PacketB newHdr.SetChecksum(0) newHdr.SetChecksum(^newHdr.CalculateChecksum()) + if route.RequiresTXTransportChecksum() { + newPkt.CalculateTransportChecksum() + } + switch err := forwardToEp.writePacketPostRouting(route, newPkt, true /* headerIncluded */); err.(type) { case nil: return nil diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index df69487c29..522a58dbc3 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -4343,3 +4343,112 @@ func TestIcmpRateLimit(t *testing.T) { }) } } + +func newTCPPacket(t *testing.T, srcAddr, dstAddr tcpip.Address, ttl uint8, tcpChecksum uint16) *stack.PacketBuffer { + t.Helper() + ipHeaderLength := header.IPv4MinimumSize + tcpHeaderLength := header.TCPMinimumSize + totalLength := ipHeaderLength + tcpHeaderLength + 10 // 10 bytes payload + hdr := prependable.New(totalLength) + + // Payload + hdr.Prepend(10) + copy(hdr.View(), []byte("1234567890")) + + // TCP Header + tcpH := header.TCP(hdr.Prepend(tcpHeaderLength)) + tcpH.Encode(&header.TCPFields{ + SrcPort: 1234, + DstPort: 80, + SeqNum: 100, + AckNum: 200, + DataOffset: uint8(tcpHeaderLength), + Flags: header.TCPFlagSyn, + WindowSize: 65535, + }) + tcpH.SetChecksum(tcpChecksum) + + // IP Header + ipH := header.IPv4(hdr.Prepend(ipHeaderLength)) + ipH.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.TCPProtocolNumber), + TTL: ttl, + SrcAddr: srcAddr, + DstAddr: dstAddr, + }) + ipH.SetChecksum(0) + ipH.SetChecksum(^ipH.CalculateChecksum()) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(hdr.View()), + }) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + return pkt +} + +func TestForwardingTCPChecksum(t *testing.T) { + ctx := newTestContext() + defer ctx.cleanup() + s := ctx.s + + endpoints := make(map[tcpip.NICID]*channel.Endpoint) + for nicID, addr := range defaultEndpointConfigs { + ep := channel.New(1, ipv4.MaxTotalSize, "") + defer ep.Close() + + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: addr} + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) + } + endpoints[nicID] = ep + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: incomingIPv4Addr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: outgoingIPv4Addr.Subnet(), + NIC: outgoingNICID, + }, + }) + + if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err) + } + + // Inject a TCP packet with checksum 0 (invalid) into incoming NIC. + requestPkt := newTCPPacket(t, remoteIPv4Addr1, remoteIPv4Addr2, 64, 0) + defer requestPkt.DecRef() + + incomingEndpoint := endpoints[incomingNICID] + incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + + outgoingEndpoint := endpoints[outgoingNICID] + reply := outgoingEndpoint.Read() + if reply == nil { + t.Fatal("Expected forwarded TCP packet through outgoing NIC") + } + defer reply.DecRef() + + // Verify that the forwarded packet has a valid TCP checksum. + payload := stack.PayloadSince(reply.NetworkHeader()) + defer payload.Release() + + ipv4Header := header.IPv4(payload.AsSlice()) + tcpHeaderBytes := ipv4Header.Payload() + tcpHeader := header.TCP(tcpHeaderBytes) + + src := ipv4Header.SourceAddress() + dst := ipv4Header.DestinationAddress() + payloadLength := uint16(len(tcpHeader.Payload())) + payloadCsum := checksum.Checksum(tcpHeader.Payload(), 0) + if !tcpHeader.IsChecksumValid(src, dst, payloadCsum, payloadLength) { + t.Errorf("expected valid TCP checksum, but got invalid") + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index b4a38a5669..57648cf86e 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1095,6 +1095,10 @@ func (e *endpoint) forwardPacketWithRoute(route *stack.Route, pkt *stack.PacketB // each node that forwards the packet. newHdr.SetHopLimit(hopLimit - 1) + if route.RequiresTXTransportChecksum() { + newPkt.CalculateTransportChecksum() + } + forwardToEp, ok := e.protocol.getEndpointForNIC(route.NICID()) if !ok { // The interface was removed after we obtained the route. diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index f5af09a230..76a49ac4b4 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -4012,3 +4012,111 @@ func TestRejectMartianMappedPackets(t *testing.T) { }) } } + +func newTCPPacket6(t *testing.T, srcAddr, dstAddr tcpip.Address, hopLimit uint8, tcpChecksum uint16) *stack.PacketBuffer { + t.Helper() + ipHeaderLength := header.IPv6MinimumSize + tcpHeaderLength := header.TCPMinimumSize + totalLength := ipHeaderLength + tcpHeaderLength + 10 // 10 bytes payload + hdr := prependable.New(totalLength) + + // Payload + hdr.Prepend(10) + copy(hdr.View(), []byte("1234567890")) + + // TCP Header + tcpH := header.TCP(hdr.Prepend(tcpHeaderLength)) + tcpH.Encode(&header.TCPFields{ + SrcPort: 1234, + DstPort: 80, + SeqNum: 100, + AckNum: 200, + DataOffset: uint8(tcpHeaderLength), + Flags: header.TCPFlagSyn, + WindowSize: 65535, + }) + tcpH.SetChecksum(tcpChecksum) + + // IP Header + ipH := header.IPv6(hdr.Prepend(ipHeaderLength)) + ipH.Encode(&header.IPv6Fields{ + PayloadLength: uint16(tcpHeaderLength + 10), + TransportProtocol: header.TCPProtocolNumber, + HopLimit: hopLimit, + SrcAddr: srcAddr, + DstAddr: dstAddr, + }) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(hdr.View()), + }) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + return pkt +} + +func TestForwardingTCPChecksum(t *testing.T) { + ctx := newTestContext() + defer ctx.cleanup() + s := ctx.s + + endpoints := make(map[tcpip.NICID]*channel.Endpoint) + for nicID, addr := range defaultEndpointConfigs { + ep := channel.New(1, header.IPv6MinimumMTU, "") + defer ep.Close() + + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + addr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: addr} + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) + } + s.SetNICMulticastForwarding(nicID, ProtocolNumber, true /* enabled */) + endpoints[nicID] = ep + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: incomingIPv6Addr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: outgoingIPv6Addr.Subnet(), + NIC: outgoingNICID, + }, + }) + + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } + + // Inject a TCP packet with checksum 0 (invalid) into incoming NIC. + requestPkt := newTCPPacket6(t, remoteIPv6Addr1, remoteIPv6Addr2, 64, 0) + defer requestPkt.DecRef() + + incomingEndpoint := endpoints[incomingNICID] + incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt) + + outgoingEndpoint := endpoints[outgoingNICID] + reply := outgoingEndpoint.Read() + if reply == nil { + t.Fatal("Expected forwarded TCP packet through outgoing NIC") + } + defer reply.DecRef() + + // Verify that the forwarded packet has a valid TCP checksum. + payload := stack.PayloadSince(reply.NetworkHeader()) + defer payload.Release() + + ipv6Header := header.IPv6(payload.AsSlice()) + tcpHeaderBytes := ipv6Header.Payload() + tcpHeader := header.TCP(tcpHeaderBytes) + + src := ipv6Header.SourceAddress() + dst := ipv6Header.DestinationAddress() + payloadLength := uint16(len(tcpHeader.Payload())) + payloadCsum := checksum.Checksum(tcpHeader.Payload(), 0) + if !tcpHeader.IsChecksumValid(src, dst, payloadCsum, payloadLength) { + t.Errorf("expected valid TCP checksum, but got invalid") + } +} diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 06ed6e671a..765e3fb09d 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checksum" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -993,3 +994,79 @@ func UpdateHeaders(n header.Network, t header.Transport, updateSRCFields, fullCh n.SetDestinationAddress(newAddr) } } + +// CalculateTransportChecksum calculates the transport-layer checksum of the +// packet. +func (pk *PacketBuffer) CalculateTransportChecksum() { + netHdr, transHdr, isICMPError, ok := pk.GetHeaders() + if isICMPError { + // Skip ICMP errors because GetHeaders() returns inner headers, but pk.Data() + // contains the outer payload (including inner IP header), which would + // corrupt the checksum calculation if used as the transport payload. + // Inner headers are already incrementally updated by NAT if needed. + // This aligns with Linux, which also relies on incremental updates for + // inner headers and does not perform full recalculation from scratch. + return + } + if !ok { + // Try to parse headers from Data if not set (e.g., forwarded packet). + if pk.NetworkProtocolNumber == 0 { + return + } + netHdr = pk.Network() + transProto := netHdr.TransportProtocol() + + var headerSize int + switch transProto { + case header.TCPProtocolNumber: + // Peek at minimum TCP header to find data offset (which includes options). + b, ok := pk.Data().PullUp(header.TCPMinimumSize) + if !ok { + return + } + tcp := header.TCP(b) + headerSize = int(tcp.DataOffset()) + if headerSize < header.TCPMinimumSize { + return + } + case header.UDPProtocolNumber: + headerSize = header.UDPMinimumSize + default: + return + } + + // Consume the transport header. + if _, ok := pk.TransportHeader().Consume(headerSize); !ok { + return + } + pk.TransportProtocolNumber = transProto + + // Refresh headers. + netHdr, transHdr, isICMPError, ok = pk.GetHeaders() + if !ok || isICMPError { + return + } + } + + var xsum uint16 + switch t := transHdr.(type) { + case header.TCP: + src := netHdr.SourceAddress() + dst := netHdr.DestinationAddress() + proto := netHdr.TransportProtocol() + totalLen := uint16(len(t) + pk.Data().Size()) + xsum = header.PseudoHeaderChecksum(proto, src, dst, totalLen) + xsum = checksum.Combine(xsum, pk.Data().Checksum()) + t.SetChecksum(0) + t.SetChecksum(^t.CalculateChecksum(xsum)) + case header.UDP: + src := netHdr.SourceAddress() + dst := netHdr.DestinationAddress() + proto := netHdr.TransportProtocol() + totalLen := uint16(len(t) + pk.Data().Size()) + xsum = header.PseudoHeaderChecksum(proto, src, dst, totalLen) + xsum = checksum.Combine(xsum, pk.Data().Checksum()) + t.SetChecksum(0) + t.SetChecksum(^t.CalculateChecksum(xsum)) + } +}