diff --git a/layers/dns.go b/layers/dns.go index 21b6de6e5..99e1a509c 100644 --- a/layers/dns.go +++ b/layers/dns.go @@ -7,6 +7,7 @@ package layers import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -252,6 +253,19 @@ func (doc DNSOpCode) String() string { // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ // DNS contains data from a single Domain Name Service packet. +// +// DNS name fields (such as DNSQuestion.Name and DNSResourceRecord.Name) hold +// names in dotted presentation form. When a packet decoded by this layer is +// re-serialized without changing a name field, the original wire label +// boundaries are preserved, so a label that legitimately contains a literal +// dot (for example a DNS-SD instance label "foo.bar") round-trips correctly. +// +// A name field that is newly constructed or changed is instead parsed as a +// presentation name, where a dot separates labels. To embed a literal dot, +// backslash, or arbitrary byte in a single label, use the escapes "\.", "\\", +// or "\DDD" (three decimal digits). A decoded name copied verbatim into a +// different field loses its preserved boundaries and is re-parsed this way, so +// any literal dot in it must be escaped first. type DNS struct { BaseLayer @@ -284,6 +298,74 @@ type DNS struct { buffer []byte } +type dnsNameLabels [][]byte + +func cloneBytes(b []byte) []byte { + if len(b) == 0 { + return nil + } + return append([]byte(nil), b...) +} + +// dnsNameMeta preserves one decoded DNS name's wire label boundaries together +// with a snapshot of the decoded presentation bytes. Serialization reproduces +// the exact labels only while the public name field still equals orig; once the +// caller changes the name, the metadata is treated as stale. +type dnsNameMeta struct { + labels dnsNameLabels + orig []byte +} + +// dnsRecordNameMeta holds the preserved name metadata for one resource record. +// It is allocated lazily, only when a decoded name actually contains a label +// boundary that the dotted presentation form cannot represent, so the common +// case costs a single nil pointer rather than a metadata field per name. A +// record carries at most its owner name plus the one or two DNS names in its +// RDATA, so rdata holds the single RDATA name (NS, CNAME, PTR, MX, SRV) or the +// SOA MName, and rdata2 holds the SOA RName. +type dnsRecordNameMeta struct { + name dnsNameMeta + rdata dnsNameMeta + rdata2 dnsNameMeta +} + +// newDNSNameMeta builds metadata for a decoded name. Call it only when labels +// is non-nil (the name needs boundary preservation). +func newDNSNameMeta(name []byte, labels dnsNameLabels) dnsNameMeta { + return dnsNameMeta{labels: labels, orig: cloneBytes(name)} +} + +// dnsLabelNeedsPreservation reports whether a wire label cannot be represented +// unambiguously in the dotted presentation Name. '.' and '\' are the only +// metacharacters of presentation form: the label separator and the escape +// introducer (see encodeDNSPresentationName). A label containing either would be +// mis-split or mis-escaped if flattened into Name and re-encoded, so its exact +// wire boundaries must be preserved. Every other byte round-trips literally. +func dnsLabelNeedsPreservation(label []byte) bool { + for _, b := range label { + if b == '.' || b == '\\' { + return true + } + } + return false +} + +func collectDNSWireLabels(data []byte, offset, end int) dnsNameLabels { + var labels dnsNameLabels + for offset < end { + if offset >= len(data) || data[offset]&0xc0 != 0 { + return labels + } + next := offset + int(data[offset]) + 1 + if next > end || next > len(data) { + return labels + } + labels = append(labels, cloneBytes(data[offset+1:next])) + offset = next + } + return labels +} + // LayerType returns gopacket.LayerTypeDNS. func (d *DNS) LayerType() gopacket.LayerType { return LayerTypeDNS } @@ -413,57 +495,202 @@ func b2i(b bool) int { return 0 } -func recSize(rr *DNSResourceRecord) int { +// encodeDNSPresentationName encodes a presentation-form DNS name (where '.' +// separates labels and '\' introduces an escape) into wire format. When data +// is non-nil it writes the encoding starting at offset; when data is nil it only +// measures. Either way it returns the wire size, so a caller sizes a buffer with +// data == nil and then fills it with identical logic: one source of truth for +// the grammar, with no separate size/write passes to keep in sync. +// +// Recognized escapes are \. \\ and \DDD (three decimal digits, value <= 255); +// any other \X is preserved literally as backslash plus X. +func encodeDNSPresentationName(name []byte, data []byte, offset int) (int, error) { + start := offset + if len(name) == 0 || (len(name) == 1 && name[0] == '.') { + if data != nil { + data[offset] = 0x00 + } + return 1, nil + } + labelOffset := offset // reserved slot for the current label's length octet + offset++ + labelLen := 0 + lastWasSeparator := false + for i := 0; i < len(name); i++ { + c := name[i] + if c == '.' { + if labelLen > 63 { + return 0, errDNSNameTooLong + } + if data != nil { + data[labelOffset] = byte(labelLen) + } + labelOffset = offset + offset++ + labelLen = 0 + lastWasSeparator = true + continue + } + if c == '\\' { + if i+1 >= len(name) { + return 0, errDNSNameInvalidIndex + } + next := name[i+1] + switch { + case next == '.' || next == '\\': + if data != nil { + data[offset] = next + } + offset++ + labelLen++ + i++ + case next >= '0' && next <= '9': + if i+3 >= len(name) || name[i+2] < '0' || name[i+2] > '9' || name[i+3] < '0' || name[i+3] > '9' { + return 0, errDNSNameInvalidIndex + } + v := int(name[i+1]-'0')*100 + int(name[i+2]-'0')*10 + int(name[i+3]-'0') + if v > 255 { + return 0, errDNSNameInvalidIndex + } + if data != nil { + data[offset] = byte(v) + } + offset++ + labelLen++ + i += 3 + default: + if data != nil { + data[offset] = c + data[offset+1] = next + } + offset += 2 + labelLen += 2 + i++ + } + lastWasSeparator = false + continue + } + if data != nil { + data[offset] = c + } + offset++ + labelLen++ + lastWasSeparator = false + } + if labelLen > 63 { + return 0, errDNSNameTooLong + } + if !lastWasSeparator { + if data != nil { + data[labelOffset] = byte(labelLen) + } + } else { + offset = labelOffset + } + if data != nil { + data[offset] = 0x00 + } + size := offset + 1 - start + if size > 255 { + return 0, errDNSNameTooLong + } + return size, nil +} + +func dnsNameLabelsSize(labels dnsNameLabels) (int, error) { + size := 1 + for _, label := range labels { + if len(label) > 63 { + return 0, errDNSNameTooLong + } + size += 1 + len(label) + if size > 255 { + return 0, errDNSNameTooLong + } + } + return size, nil +} + +func usePreservedDNSLabels(name []byte, m *dnsNameMeta) bool { + return m != nil && m.labels != nil && bytes.Equal(name, m.orig) +} + +func dnsNameSize(name []byte, m *dnsNameMeta) (int, error) { + if usePreservedDNSLabels(name, m) { + return dnsNameLabelsSize(m.labels) + } + return encodeDNSPresentationName(name, nil, 0) +} + +func recSize(rr *DNSResourceRecord) (int, error) { switch rr.Type { case DNSTypeA: - return 4 + return 4, nil case DNSTypeAAAA: - return 16 + return 16, nil case DNSTypeNS: - return len(rr.NS) + 2 + return dnsNameSize(rr.NS, rr.rdataMeta()) case DNSTypeCNAME: - return len(rr.CNAME) + 2 + return dnsNameSize(rr.CNAME, rr.rdataMeta()) case DNSTypePTR: - return len(rr.PTR) + 2 + return dnsNameSize(rr.PTR, rr.rdataMeta()) case DNSTypeSOA: - return len(rr.SOA.MName) + 2 + len(rr.SOA.RName) + 2 + 20 + mNameSize, err := dnsNameSize(rr.SOA.MName, rr.rdataMeta()) + if err != nil { + return 0, err + } + rNameSize, err := dnsNameSize(rr.SOA.RName, rr.rdata2Meta()) + if err != nil { + return 0, err + } + return mNameSize + rNameSize + 20, nil case DNSTypeMX: - return 2 + len(rr.MX.Name) + 2 + nameSize, err := dnsNameSize(rr.MX.Name, rr.rdataMeta()) + if err != nil { + return 0, err + } + return 2 + nameSize, nil case DNSTypeTXT: l := len(rr.TXTs) for _, txt := range rr.TXTs { l += len(txt) } - return l + return l, nil case DNSTypeSRV: - return 6 + len(rr.SRV.Name) + 2 + nameSize, err := dnsNameSize(rr.SRV.Name, rr.rdataMeta()) + if err != nil { + return 0, err + } + return 6 + nameSize, nil case DNSTypeURI: - return 4 + len(rr.URI.Target) + return 4 + len(rr.URI.Target), nil case DNSTypeOPT: l := len(rr.OPT) * 4 for _, opt := range rr.OPT { l += len(opt.Data) } - return l + return l, nil } - return 0 + return 0, nil } -func computeSize(recs []DNSResourceRecord) int { +func computeSize(recs []DNSResourceRecord) (int, error) { sz := 0 for _, rr := range recs { - v := len(rr.Name) - - if v == 0 { - sz += v + 11 - } else { - sz += v + 12 + v, err := dnsNameSize(rr.Name, rr.ownerMeta()) + if err != nil { + return 0, err } + sz += v + 10 - sz += recSize(&rr) + rSz, err := recSize(&rr) + if err != nil { + return 0, err + } + sz += rSz } - return sz + return sz, nil } // SerializeTo writes the serialized form of this layer into the @@ -471,11 +698,27 @@ func computeSize(recs []DNSResourceRecord) int { func (d *DNS) SerializeTo(b gopacket.SerializeBuffer, opts gopacket.SerializeOptions) error { dsz := 0 for _, q := range d.Questions { - dsz += len(q.Name) + 6 + qSize, err := dnsNameSize(q.Name, q.nameMeta) + if err != nil { + return err + } + dsz += qSize + 4 + } + answersSize, err := computeSize(d.Answers) + if err != nil { + return err + } + dsz += answersSize + authoritiesSize, err := computeSize(d.Authorities) + if err != nil { + return err + } + dsz += authoritiesSize + additionalsSize, err := computeSize(d.Additionals) + if err != nil { + return err } - dsz += computeSize(d.Answers) - dsz += computeSize(d.Authorities) - dsz += computeSize(d.Additionals) + dsz += additionalsSize bytes, err := b.PrependBytes(12 + dsz) if err != nil { @@ -498,7 +741,10 @@ func (d *DNS) SerializeTo(b gopacket.SerializeBuffer, opts gopacket.SerializeOpt off := 12 for _, qd := range d.Questions { - n := qd.encode(bytes, off) + n, err := qd.encode(bytes, off) + if err != nil { + return err + } off += n } @@ -535,18 +781,19 @@ func (d *DNS) SerializeTo(b gopacket.SerializeBuffer, opts gopacket.SerializeOpt const maxRecursionLevel = 255 -func decodeName(data []byte, offset int, buffer *[]byte, level int) ([]byte, int, error) { +func decodeName(data []byte, offset int, buffer *[]byte, level int) ([]byte, dnsNameLabels, int, error) { if level > maxRecursionLevel { - return nil, 0, errMaxRecursion + return nil, nil, 0, errMaxRecursion } else if offset >= len(data) { - return nil, 0, errDNSNameOffsetTooHigh + return nil, nil, 0, errDNSNameOffsetTooHigh } else if offset < 0 { - return nil, 0, errDNSNameOffsetNegative + return nil, nil, 0, errDNSNameOffsetNegative } start := len(*buffer) index := offset + var labels dnsNameLabels if data[index] == 0x00 { - return nil, index + 1, nil + return nil, labels, index + 1, nil } loop: for data[index] != 0x00 { @@ -562,12 +809,18 @@ loop: */ index2 := index + int(data[index]) + 1 if index2-offset > 255 { - return nil, 0, errDNSNameTooLong + return nil, nil, 0, errDNSNameTooLong } else if index2 < index+1 || index2 > len(data) { - return nil, 0, errDNSNameInvalidIndex + return nil, nil, 0, errDNSNameInvalidIndex } + label := data[index+1 : index2] *buffer = append(*buffer, '.') - *buffer = append(*buffer, data[index+1:index2]...) + *buffer = append(*buffer, label...) + if labels != nil { + labels = append(labels, cloneBytes(label)) + } else if dnsLabelNeedsPreservation(label) { + labels = collectDNSWireLabels(data, offset, index2) + } index = index2 case 0xc0: @@ -591,39 +844,49 @@ loop: - a sequence of labels ending with a pointer */ if index+2 > len(data) { - return nil, 0, errDNSPointerOffsetTooHigh + return nil, nil, 0, errDNSPointerOffsetTooHigh } offsetp := int(binary.BigEndian.Uint16(data[index:index+2]) & 0x3fff) if offsetp > len(data) { - return nil, 0, errDNSPointerOffsetTooHigh + return nil, nil, 0, errDNSPointerOffsetTooHigh } // This looks a little tricky, but actually isn't. Because of how // decodeName is written, calling it appends the decoded name to the // current buffer. We already have the start of the buffer, then, so // once this call is done buffer[start:] will contain our full name. - _, _, err := decodeName(data, offsetp, buffer, level+1) + pointedName, pointedLabels, _, err := decodeName(data, offsetp, buffer, level+1) if err != nil { - return nil, 0, err + return nil, nil, 0, err + } + if pointedLabels != nil { + if labels == nil { + labels = collectDNSWireLabels(data, offset, index) + } + labels = append(labels, pointedLabels...) + } else if labels != nil && len(pointedName) > 0 { + for _, lbl := range bytes.Split(pointedName, []byte{'.'}) { + labels = append(labels, cloneBytes(lbl)) + } } index++ // pointer is two bytes, so add an extra byte here. break loop /* EDNS, or other DNS option ? */ case 0x40: // RFC 2673 - return nil, 0, fmt.Errorf("qname '0x40' - RFC 2673 unsupported yet (data=%x index=%d)", + return nil, nil, 0, fmt.Errorf("qname '0x40' - RFC 2673 unsupported yet (data=%x index=%d)", data[index], index) case 0x80: - return nil, 0, fmt.Errorf("qname '0x80' unsupported yet (data=%x index=%d)", + return nil, nil, 0, fmt.Errorf("qname '0x80' unsupported yet (data=%x index=%d)", data[index], index) } if index >= len(data) { - return nil, 0, errDNSIndexOutOfRange + return nil, nil, 0, errDNSIndexOutOfRange } } if len(*buffer) <= start { - return (*buffer)[start:], index + 1, nil + return (*buffer)[start:], labels, index + 1, nil } - return (*buffer)[start+1:], index + 1, nil + return (*buffer)[start+1:], labels, index + 1, nil } // DNSQuestion wraps a single request (question) within a DNS query. @@ -631,10 +894,12 @@ type DNSQuestion struct { Name []byte Type DNSType Class DNSClass + + nameMeta *dnsNameMeta } func (q *DNSQuestion) decode(data []byte, offset int, df gopacket.DecodeFeedback, buffer *[]byte) (int, error) { - name, endq, err := decodeName(data, offset, buffer, 1) + name, labels, endq, err := decodeName(data, offset, buffer, 1) if err != nil { return 0, err } @@ -644,18 +909,25 @@ func (q *DNSQuestion) decode(data []byte, offset int, df gopacket.DecodeFeedback } q.Name = name + if labels != nil { + meta := newDNSNameMeta(name, labels) + q.nameMeta = &meta + } q.Type = DNSType(binary.BigEndian.Uint16(data[endq : endq+2])) q.Class = DNSClass(binary.BigEndian.Uint16(data[endq+2 : endq+4])) return endq + 4, nil } -func (q *DNSQuestion) encode(data []byte, offset int) int { - noff := encodeName(q.Name, data, offset) - nSz := noff - offset +func (q *DNSQuestion) encode(data []byte, offset int) (int, error) { + nSz, err := encodeDNSName(q.Name, q.nameMeta, data, offset) + if err != nil { + return 0, err + } + noff := offset + nSz binary.BigEndian.PutUint16(data[noff:], uint16(q.Type)) binary.BigEndian.PutUint16(data[noff+2:], uint16(q.Class)) - return nSz + 4 + return nSz + 4, nil } // DNSResourceRecord @@ -704,11 +976,47 @@ type DNSResourceRecord struct { // Undecoded TXT for backward compatibility TXT []byte + + names *dnsRecordNameMeta +} + +// ensureNameMeta returns the record's preserved-name metadata block, allocating it +// on first use. Only call it when a decoded name actually needs preservation. +func (rr *DNSResourceRecord) ensureNameMeta() *dnsRecordNameMeta { + if rr.names == nil { + rr.names = &dnsRecordNameMeta{} + } + return rr.names +} + +// ownerMeta, rdataMeta, and rdata2Meta return the preserved metadata for the +// owner name, the single RDATA name (or SOA MName), and the SOA RName. They +// return nil when no preservation metadata was recorded, which the size and +// encode helpers treat as "use presentation form". +func (rr *DNSResourceRecord) ownerMeta() *dnsNameMeta { + if rr.names == nil { + return nil + } + return &rr.names.name +} + +func (rr *DNSResourceRecord) rdataMeta() *dnsNameMeta { + if rr.names == nil { + return nil + } + return &rr.names.rdata +} + +func (rr *DNSResourceRecord) rdata2Meta() *dnsNameMeta { + if rr.names == nil { + return nil + } + return &rr.names.rdata2 } // decode decodes the resource record, returning the total length of the record. func (rr *DNSResourceRecord) decode(data []byte, offset int, df gopacket.DecodeFeedback, buffer *[]byte) (int, error) { - name, endq, err := decodeName(data, offset, buffer, 1) + name, labels, endq, err := decodeName(data, offset, buffer, 1) if err != nil { return 0, err } @@ -718,6 +1026,9 @@ func (rr *DNSResourceRecord) decode(data []byte, offset int, df gopacket.DecodeF } rr.Name = name + if labels != nil { + rr.ensureNameMeta().name = newDNSNameMeta(name, labels) + } rr.Type = DNSType(binary.BigEndian.Uint16(data[endq : endq+2])) rr.Class = DNSClass(binary.BigEndian.Uint16(data[endq+2 : endq+4])) rr.TTL = binary.BigEndian.Uint32(data[endq+4 : endq+8]) @@ -737,34 +1048,39 @@ func (rr *DNSResourceRecord) decode(data []byte, offset int, df gopacket.DecodeF return endq + 10 + int(rr.DataLength), nil } -func encodeName(name []byte, data []byte, offset int) int { - l := 0 - for i := range name { - if name[i] == '.' { - data[offset+i-l] = byte(l) - l = 0 - } else { - // skip one to write the length - data[offset+i+1] = name[i] - l++ - } +func encodeDNSNameLabels(labels dnsNameLabels, data []byte, offset int) (int, error) { + size, err := dnsNameLabelsSize(labels) + if err != nil { + return 0, err } - - if len(name) == 0 { - data[offset] = 0x00 // terminal - return offset + 1 + start := offset + for _, label := range labels { + data[offset] = byte(len(label)) + offset++ + copy(data[offset:], label) + offset += len(label) } + data[offset] = 0x00 + if offset+1-start != size { + return 0, errDNSNameInvalidIndex + } + return size, nil +} - // length for final portion - data[offset+len(name)-l] = byte(l) - data[offset+len(name)+1] = 0x00 // terminal - return offset + len(name) + 2 +func encodeDNSName(name []byte, m *dnsNameMeta, data []byte, offset int) (int, error) { + if usePreservedDNSLabels(name, m) { + return encodeDNSNameLabels(m.labels, data, offset) + } + return encodeDNSPresentationName(name, data, offset) } func (rr *DNSResourceRecord) encode(data []byte, offset int, opts gopacket.SerializeOptions) (int, error) { - noff := encodeName(rr.Name, data, offset) - nSz := noff - offset + nSz, err := encodeDNSName(rr.Name, rr.ownerMeta(), data, offset) + if err != nil { + return 0, err + } + noff := offset + nSz binary.BigEndian.PutUint16(data[noff:], uint16(rr.Type)) binary.BigEndian.PutUint16(data[noff+2:], uint16(rr.Class)) @@ -776,14 +1092,27 @@ func (rr *DNSResourceRecord) encode(data []byte, offset int, opts gopacket.Seria case DNSTypeAAAA: copy(data[noff+10:], rr.IP) case DNSTypeNS: - encodeName(rr.NS, data, noff+10) + if _, err = encodeDNSName(rr.NS, rr.rdataMeta(), data, noff+10); err != nil { + return 0, err + } case DNSTypeCNAME: - encodeName(rr.CNAME, data, noff+10) + if _, err = encodeDNSName(rr.CNAME, rr.rdataMeta(), data, noff+10); err != nil { + return 0, err + } case DNSTypePTR: - encodeName(rr.PTR, data, noff+10) + if _, err = encodeDNSName(rr.PTR, rr.rdataMeta(), data, noff+10); err != nil { + return 0, err + } case DNSTypeSOA: - noff2 := encodeName(rr.SOA.MName, data, noff+10) - noff2 = encodeName(rr.SOA.RName, data, noff2) + n1, err := encodeDNSName(rr.SOA.MName, rr.rdataMeta(), data, noff+10) + if err != nil { + return 0, err + } + n2, err := encodeDNSName(rr.SOA.RName, rr.rdata2Meta(), data, noff+10+n1) + if err != nil { + return 0, err + } + noff2 := noff + 10 + n1 + n2 binary.BigEndian.PutUint32(data[noff2:], rr.SOA.Serial) binary.BigEndian.PutUint32(data[noff2+4:], rr.SOA.Refresh) binary.BigEndian.PutUint32(data[noff2+8:], rr.SOA.Retry) @@ -791,7 +1120,9 @@ func (rr *DNSResourceRecord) encode(data []byte, offset int, opts gopacket.Seria binary.BigEndian.PutUint32(data[noff2+16:], rr.SOA.Minimum) case DNSTypeMX: binary.BigEndian.PutUint16(data[noff+10:], rr.MX.Preference) - encodeName(rr.MX.Name, data, noff+12) + if _, err = encodeDNSName(rr.MX.Name, rr.rdataMeta(), data, noff+12); err != nil { + return 0, err + } case DNSTypeTXT: noff2 := noff + 10 for _, txt := range rr.TXTs { @@ -803,7 +1134,9 @@ func (rr *DNSResourceRecord) encode(data []byte, offset int, opts gopacket.Seria binary.BigEndian.PutUint16(data[noff+10:], rr.SRV.Priority) binary.BigEndian.PutUint16(data[noff+12:], rr.SRV.Weight) binary.BigEndian.PutUint16(data[noff+14:], rr.SRV.Port) - encodeName(rr.SRV.Name, data, noff+16) + if _, err = encodeDNSName(rr.SRV.Name, rr.rdataMeta(), data, noff+16); err != nil { + return 0, err + } case DNSTypeURI: binary.BigEndian.PutUint16(data[noff+10:], rr.URI.Priority) binary.BigEndian.PutUint16(data[noff+12:], rr.URI.Weight) @@ -821,7 +1154,10 @@ func (rr *DNSResourceRecord) encode(data []byte, offset int, opts gopacket.Seria } // DataLength - dSz := recSize(rr) + dSz, err := recSize(rr) + if err != nil { + return 0, err + } binary.BigEndian.PutUint16(data[noff+8:], uint16(dSz)) if opts.FixLengths { @@ -917,30 +1253,42 @@ func (rr *DNSResourceRecord) decodeRData(data []byte, offset int, buffer *[]byte } rr.TXTs = txts case DNSTypeNS: - name, _, err := decodeName(data, offset, buffer, 1) + name, labels, _, err := decodeName(data, offset, buffer, 1) if err != nil { return err } rr.NS = name + if labels != nil { + rr.ensureNameMeta().rdata = newDNSNameMeta(name, labels) + } case DNSTypeCNAME: - name, _, err := decodeName(data, offset, buffer, 1) + name, labels, _, err := decodeName(data, offset, buffer, 1) if err != nil { return err } rr.CNAME = name + if labels != nil { + rr.ensureNameMeta().rdata = newDNSNameMeta(name, labels) + } case DNSTypePTR: - name, _, err := decodeName(data, offset, buffer, 1) + name, labels, _, err := decodeName(data, offset, buffer, 1) if err != nil { return err } rr.PTR = name + if labels != nil { + rr.ensureNameMeta().rdata = newDNSNameMeta(name, labels) + } case DNSTypeSOA: - name, endq, err := decodeName(data, offset, buffer, 1) + name, labels, endq, err := decodeName(data, offset, buffer, 1) if err != nil { return err } rr.SOA.MName = name - name, endq, err = decodeName(data, endq, buffer, 1) + if labels != nil { + rr.ensureNameMeta().rdata = newDNSNameMeta(name, labels) + } + name, labels, endq, err = decodeName(data, endq, buffer, 1) if err != nil { return err } @@ -948,6 +1296,9 @@ func (rr *DNSResourceRecord) decodeRData(data []byte, offset int, buffer *[]byte return errors.New("SOA too small") } rr.SOA.RName = name + if labels != nil { + rr.ensureNameMeta().rdata2 = newDNSNameMeta(name, labels) + } rr.SOA.Serial = binary.BigEndian.Uint32(data[endq : endq+4]) rr.SOA.Refresh = binary.BigEndian.Uint32(data[endq+4 : endq+8]) rr.SOA.Retry = binary.BigEndian.Uint32(data[endq+8 : endq+12]) @@ -958,11 +1309,14 @@ func (rr *DNSResourceRecord) decodeRData(data []byte, offset int, buffer *[]byte return errors.New("MX too small") } rr.MX.Preference = binary.BigEndian.Uint16(data[offset : offset+2]) - name, _, err := decodeName(data, offset+2, buffer, 1) + name, labels, _, err := decodeName(data, offset+2, buffer, 1) if err != nil { return err } rr.MX.Name = name + if labels != nil { + rr.ensureNameMeta().rdata = newDNSNameMeta(name, labels) + } case DNSTypeURI: if len(rr.Data) < 4 { return errors.New("URI too small") @@ -977,11 +1331,14 @@ func (rr *DNSResourceRecord) decodeRData(data []byte, offset int, buffer *[]byte rr.SRV.Priority = binary.BigEndian.Uint16(data[offset : offset+2]) rr.SRV.Weight = binary.BigEndian.Uint16(data[offset+2 : offset+4]) rr.SRV.Port = binary.BigEndian.Uint16(data[offset+4 : offset+6]) - name, _, err := decodeName(data, offset+6, buffer, 1) + name, labels, _, err := decodeName(data, offset+6, buffer, 1) if err != nil { return err } rr.SRV.Name = name + if labels != nil { + rr.ensureNameMeta().rdata = newDNSNameMeta(name, labels) + } case DNSTypeOPT: allOPT, err := decodeOPTs(data, offset) if err != nil { diff --git a/layers/dns_test.go b/layers/dns_test.go index 5d7cbad32..2455df5e8 100644 --- a/layers/dns_test.go +++ b/layers/dns_test.go @@ -22,6 +22,22 @@ func FuzzDecodeFromBytes(f *testing.F) { }) } +// FuzzDecodeSerializeDNS fuzzes the decode -> serialize round trip: any input +// that decodes without error must re-serialize without panicking. This guards +// the name encoder's buffer sizing, which sizes the output with one routine +// and writes it with another, against ever disagreeing. +func FuzzDecodeSerializeDNS(f *testing.F) { + f.Add(testPacketDNSNilRdata) + f.Fuzz(func(t *testing.T, data []byte) { + var dns DNS + if err := dns.DecodeFromBytes(data, gopacket.NilDecodeFeedback); err != nil { + return + } + buf := gopacket.NewSerializeBuffer() + _ = dns.SerializeTo(buf, gopacket.SerializeOptions{FixLengths: true}) + }) +} + // it have a layer like that: // name: xxx.com // type: CNAME @@ -466,6 +482,526 @@ func testDNSEqual(t *testing.T, exp, got *DNS) { } } +func mustSerializeDNS(t *testing.T, dns *DNS) []byte { + t.Helper() + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{FixLengths: true} + if err := gopacket.SerializeLayers(buf, opts, dns); err != nil { + t.Fatal(err) + } + return buf.Bytes() +} + +func serializeDNSError(dns *DNS) error { + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{FixLengths: true} + return gopacket.SerializeLayers(buf, opts, dns) +} + +func dnsWireName(labels ...[]byte) []byte { + var out []byte + for _, label := range labels { + out = append(out, byte(len(label))) + out = append(out, label...) + } + return append(out, 0) +} + +func TestDNSEncodeNamePresentationForm(t *testing.T) { + tests := []struct { + name string + input []byte + qtype DNSType + want []byte + mustNotSee []byte + exact bool + }{ + { + name: "plain dot separates labels", + input: []byte("foo.bar"), + qtype: DNSTypeA, + want: dnsWireName([]byte("foo"), []byte("bar")), + mustNotSee: dnsWireName([]byte("foo.bar")), + }, + { + name: "root name", + input: []byte("."), + qtype: DNSTypeNS, + want: []byte{ + 0x04, 0xd2, 0x01, 0x00, + 0x00, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, + 0x00, 0x02, + 0x00, 0x01, + }, + exact: true, + }, + { + name: "escaped dot is label data", + input: []byte(`foo\.bar`), + qtype: DNSTypeA, + want: dnsWireName([]byte("foo.bar")), + mustNotSee: dnsWireName([]byte(`foo\`), []byte("bar")), + }, + { + name: "trailing escaped dot is label data", + input: []byte(`foo\.`), + qtype: DNSTypeA, + want: dnsWireName([]byte("foo.")), + }, + { + name: "escaped backslash is label data", + input: []byte(`foo\\bar.example`), + qtype: DNSTypeA, + want: dnsWireName([]byte(`foo\bar`), []byte("example")), + }, + { + name: "unknown escape remains literal", + input: []byte(`foo\qbar.example`), + qtype: DNSTypeA, + want: dnsWireName([]byte(`foo\qbar`), []byte("example")), + }, + { + name: "decimal escape sequence", + input: []byte(`foo\065bar`), + qtype: DNSTypeA, + want: dnsWireName([]byte("fooAbar")), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dns := &DNS{ID: 1234, OpCode: DNSOpCodeQuery, RD: true} + dns.Questions = []DNSQuestion{{ + Name: tt.input, + Type: tt.qtype, + Class: DNSClassIN, + }} + + got := mustSerializeDNS(t, dns) + if tt.exact { + if !bytes.Equal(got, tt.want) { + t.Fatalf("serialized DNS = %x, want exactly %x", got, tt.want) + } + } else if !bytes.Contains(got, tt.want) { + t.Fatalf("serialized DNS did not contain name %x in %x", tt.want, got) + } + if len(tt.mustNotSee) > 0 && bytes.Contains(got, tt.mustNotSee) { + t.Fatalf("serialized DNS contained unwanted name %x in %x", tt.mustNotSee, got) + } + }) + } +} + +func TestDNSDecodeSerializePreservesLiteralDotLabel(t *testing.T) { + instance := []byte("foo.bar") + wireName := dnsWireName(instance, []byte("_googlecast"), []byte("_tcp"), []byte("local")) + msg := append([]byte{ + 0x12, 0x34, // ID + 0x84, 0x00, // response + 0x00, 0x00, // QDCOUNT + 0x00, 0x01, // ANCOUNT + 0x00, 0x00, // NSCOUNT + 0x00, 0x00, // ARCOUNT + }, wireName...) + msg = append(msg, + 0x00, 0x0c, // PTR + 0x00, 0x01, // IN + 0x00, 0x00, 0x00, 0x78, // TTL + ) + ptrTarget := dnsWireName(instance, []byte("_googlecast"), []byte("_tcp"), []byte("local")) + msg = append(msg, byte(len(ptrTarget)>>8), byte(len(ptrTarget))) + msg = append(msg, ptrTarget...) + + packet := gopacket.NewPacket(msg, LayerTypeDNS, testDecodeOptions) + if errLayer := packet.ErrorLayer(); errLayer != nil { + t.Fatal(errLayer.Error()) + } + decoded := packet.Layer(LayerTypeDNS).(*DNS) + if got := string(decoded.Answers[0].Name); got != "foo.bar._googlecast._tcp.local" { + t.Fatalf("legacy decoded owner name changed: %q", got) + } + if got := string(decoded.Answers[0].PTR); got != "foo.bar._googlecast._tcp.local" { + t.Fatalf("legacy decoded PTR name changed: %q", got) + } + + out := mustSerializeDNS(t, decoded) + if !bytes.Contains(out, wireName) { + t.Fatalf("serialized DNS did not preserve owner literal-dot label %x in %x", wireName, out) + } + if !bytes.Contains(out, ptrTarget) { + t.Fatalf("serialized DNS did not preserve PTR literal-dot label %x in %x", ptrTarget, out) + } + if bytes.Contains(out, dnsWireName([]byte("foo"), []byte("bar"), []byte("_googlecast"), []byte("_tcp"), []byte("local"))) { + t.Fatalf("serialized DNS split literal-dot label: %x", out) + } +} + +func TestDNSDecodeSerializePreservesLiteralDotLabelWithCompression(t *testing.T) { + instance := []byte("foo.bar") + qName := dnsWireName(instance, []byte("_googlecast"), []byte("_tcp"), []byte("local")) + + msg := append([]byte{ + 0x12, 0x34, // ID + 0x84, 0x00, // response + 0x00, 0x01, // QDCOUNT + 0x00, 0x01, // ANCOUNT + 0x00, 0x00, // NSCOUNT + 0x00, 0x00, // ARCOUNT + }, qName...) + msg = append(msg, + 0x00, 0x01, // A + 0x00, 0x01, // IN + ) + + instanceWire := dnsWireName(instance) + ansName := append(append([]byte(nil), instanceWire[:len(instanceWire)-1]...), 0xc0, 20) + msg = append(msg, ansName...) + msg = append(msg, + 0x00, 0x01, // A + 0x00, 0x01, // IN + 0x00, 0x00, 0x00, 0x78, // TTL + 0x00, 0x04, // DataLength + 192, 0, 2, 1, // IP + ) + + packet := gopacket.NewPacket(msg, LayerTypeDNS, testDecodeOptions) + if errLayer := packet.ErrorLayer(); errLayer != nil { + t.Fatal(errLayer.Error()) + } + decoded := packet.Layer(LayerTypeDNS).(*DNS) + if got := string(decoded.Questions[0].Name); got != "foo.bar._googlecast._tcp.local" { + t.Fatalf("decoded question name changed: %q", got) + } + if got := string(decoded.Answers[0].Name); got != "foo.bar._googlecast._tcp.local" { + t.Fatalf("decoded answer name changed: %q", got) + } + + out := mustSerializeDNS(t, decoded) + wantUncompressed := dnsWireName(instance, []byte("_googlecast"), []byte("_tcp"), []byte("local")) + count := bytes.Count(out, wantUncompressed) + if count != 2 { + t.Fatalf("expected 2 copies of uncompressed name %x in serialized bytes %x, got %d", wantUncompressed, out, count) + } + if bytes.Contains(out, dnsWireName([]byte("foo"), []byte("bar"), []byte("_googlecast"), []byte("_tcp"), []byte("local"))) { + t.Fatalf("serialized DNS split literal-dot label in compression test: %x", out) + } +} + +func TestDNSDecodeMutateRDataNameKeepsOwnerLabels(t *testing.T) { + owner := dnsWireName([]byte("foo.bar"), []byte("local")) + rdata := dnsWireName([]byte("baz.qux"), []byte("local")) + msg := append([]byte{ + 0x12, 0x34, // ID + 0x84, 0x00, // response + 0x00, 0x00, // QDCOUNT + 0x00, 0x01, // ANCOUNT + 0x00, 0x00, // NSCOUNT + 0x00, 0x00, // ARCOUNT + }, owner...) + msg = append(msg, + 0x00, 0x05, // CNAME + 0x00, 0x01, // IN + 0x00, 0x00, 0x00, 0x78, // TTL + byte(len(rdata)>>8), byte(len(rdata)), + ) + msg = append(msg, rdata...) + + packet := gopacket.NewPacket(msg, LayerTypeDNS, testDecodeOptions) + if errLayer := packet.ErrorLayer(); errLayer != nil { + t.Fatal(errLayer.Error()) + } + decoded := packet.Layer(LayerTypeDNS).(*DNS) + decoded.Answers[0].CNAME = []byte("new.target.local") + + out := mustSerializeDNS(t, decoded) + if !bytes.Contains(out, owner) { + t.Fatalf("serialized DNS did not preserve unchanged owner literal-dot label %x in %x", owner, out) + } + wantCNAME := dnsWireName([]byte("new"), []byte("target"), []byte("local")) + if !bytes.Contains(out, wantCNAME) { + t.Fatalf("changed CNAME did not use presentation fallback %x in %x", wantCNAME, out) + } +} + +func TestDNSDecodeMutateNameUsesPresentationFallback(t *testing.T) { + wireName := dnsWireName([]byte("foo.bar"), []byte("_tcp"), []byte("local")) + msg := append([]byte{ + 0x12, 0x34, + 0x01, 0x00, + 0x00, 0x01, + 0x00, 0x00, + 0x00, 0x00, + 0x00, 0x00, + }, wireName...) + msg = append(msg, + 0x00, 0x01, + 0x00, 0x01, + ) + + packet := gopacket.NewPacket(msg, LayerTypeDNS, testDecodeOptions) + if errLayer := packet.ErrorLayer(); errLayer != nil { + t.Fatal(errLayer.Error()) + } + decoded := packet.Layer(LayerTypeDNS).(*DNS) + decoded.Questions[0].Name = []byte("bar.baz._tcp.local") + + out := mustSerializeDNS(t, decoded) + wantName := dnsWireName([]byte("bar"), []byte("baz"), []byte("_tcp"), []byte("local")) + if !bytes.Contains(out, wantName) { + t.Fatalf("changed public name did not use presentation fallback %x in %x", wantName, out) + } + if bytes.Contains(out, wireName) { + t.Fatalf("stale decoded labels were used after public name mutation: %x", out) + } +} + +func TestDNSEncodeNameValidation(t *testing.T) { + var name []byte + for i := 0; i < 4; i++ { + if i > 0 { + name = append(name, '.') + } + name = append(name, bytes.Repeat([]byte{'a'}, 63)...) + } + + tests := []struct { + name string + input []byte + wantErr bool + }{ + {name: "max length label", input: bytes.Repeat([]byte{'a'}, 63)}, + {name: "long label", input: bytes.Repeat([]byte{'a'}, 64), wantErr: true}, + {name: "long name", input: name, wantErr: true}, + {name: "invalid decimal escape", input: []byte(`foo\999`), wantErr: true}, + {name: "truncated decimal escape", input: []byte(`foo\12`), wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dns := &DNS{ID: 1234, OpCode: DNSOpCodeQuery, RD: true} + dns.Questions = []DNSQuestion{{ + Name: tt.input, + Type: DNSTypeA, + Class: DNSClassIN, + }} + + err := serializeDNSError(dns) + if tt.wantErr && err == nil { + t.Fatal("expected error") + } + if !tt.wantErr && err != nil { + t.Fatalf("expected no error, got %v", err) + } + }) + } +} + +func TestDNSDecodeSerializePreservesLiteralDotRDATA(t *testing.T) { + owner := dnsWireName([]byte("example"), []byte("local")) + literalName := dnsWireName([]byte("foo.bar"), []byte("example"), []byte("local")) + splitName := dnsWireName([]byte("foo"), []byte("bar"), []byte("example"), []byte("local")) + + buildMessage := func(dnsType DNSType, rdata []byte) []byte { + msg := append([]byte{ + 0x12, 0x34, + 0x84, 0x00, + 0x00, 0x00, + 0x00, 0x01, + 0x00, 0x00, + 0x00, 0x00, + }, owner...) + msg = append(msg, + byte(uint16(dnsType)>>8), byte(uint16(dnsType)), + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x78, + byte(len(rdata)>>8), byte(len(rdata)), + ) + msg = append(msg, rdata...) + return msg + } + + soaRData := append([]byte{}, literalName...) + soaRData = append(soaRData, literalName...) + soaRData = append(soaRData, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x02, + 0x00, 0x00, 0x00, 0x03, + 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x05, + ) + + tests := []struct { + name string + dnsType DNSType + rdata []byte + }{ + {name: "NS", dnsType: DNSTypeNS, rdata: literalName}, + {name: "CNAME", dnsType: DNSTypeCNAME, rdata: literalName}, + {name: "PTR", dnsType: DNSTypePTR, rdata: literalName}, + {name: "MX", dnsType: DNSTypeMX, rdata: append([]byte{0x00, 0x0a}, literalName...)}, + {name: "SRV", dnsType: DNSTypeSRV, rdata: append([]byte{0x00, 0x01, 0x00, 0x02, 0x1f, 0x90}, literalName...)}, + {name: "SOA", dnsType: DNSTypeSOA, rdata: soaRData}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + packet := gopacket.NewPacket(buildMessage(tt.dnsType, tt.rdata), LayerTypeDNS, testDecodeOptions) + if errLayer := packet.ErrorLayer(); errLayer != nil { + t.Fatal(errLayer.Error()) + } + decoded := packet.Layer(LayerTypeDNS).(*DNS) + out := mustSerializeDNS(t, decoded) + if !bytes.Contains(out, literalName) { + t.Fatalf("serialized DNS did not preserve literal-dot RDATA name %x in %x", literalName, out) + } + if bytes.Contains(out, splitName) { + t.Fatalf("serialized DNS split literal-dot RDATA name: %x", out) + } + }) + } +} + +func TestDNSDecodeSerializePreservesLabelsAfterInputReuse(t *testing.T) { + // Regression: preserved labels must own their bytes, not alias the caller's + // input buffer. gopacket's DecodingLayerParser and zero-copy readers reuse + // the input slice across packets; decoded names are copied into the layer's + // own buffer, so the preserved label metadata must be copied too. A label + // after the first literal-dot label exercises the incremental append path. + owner := dnsWireName([]byte("plain"), []byte("foo.bar"), []byte("tail")) + msg := append([]byte{ + 0x12, 0x34, // ID + 0x84, 0x00, // response + 0x00, 0x00, // QDCOUNT + 0x00, 0x01, // ANCOUNT + 0x00, 0x00, // NSCOUNT + 0x00, 0x00, // ARCOUNT + }, owner...) + msg = append(msg, + 0x00, 0x01, // A + 0x00, 0x01, // IN + 0x00, 0x00, 0x00, 0x78, // TTL + 0x00, 0x04, // DataLength + 192, 0, 2, 1, // IP + ) + + var decoded DNS + if err := decoded.DecodeFromBytes(msg, gopacket.NilDecodeFeedback); err != nil { + t.Fatal(err) + } + if got := string(decoded.Answers[0].Name); got != "plain.foo.bar.tail" { + t.Fatalf("decoded owner name changed: %q", got) + } + + // Simulate buffer reuse: scribble over the input region holding the name. + for i := 12; i < 12+len(owner); i++ { + msg[i] = 0xff + } + + out := mustSerializeDNS(t, &decoded) + if !bytes.Contains(out, owner) { + t.Fatalf("preserved labels aliased the input buffer; serialized name corrupted: %x", out) + } +} + +func TestDNSDecodeReuseClearsStaleNameMeta(t *testing.T) { + // Decoding into a reused DNS layer (the DecodingLayerParser pattern) must not + // leak preserved label metadata from a previous packet. The decode loop + // overwrites each record slot with a zero DNSResourceRecord, which resets the + // lazily-allocated names pointer; this test guards that invariant. + header := []byte{0x00, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00} + aRecord := []byte{0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x78, 0x00, 0x04, 192, 0, 2, 1} + + withDot := append(append([]byte{}, header...), dnsWireName([]byte("foo.bar"), []byte("local"))...) + withDot = append(withDot, aRecord...) + plain := append(append([]byte{}, header...), dnsWireName([]byte("plain"), []byte("local"))...) + plain = append(plain, aRecord...) + + var d DNS + if err := d.DecodeFromBytes(withDot, gopacket.NilDecodeFeedback); err != nil { + t.Fatal(err) + } + if d.Answers[0].names == nil { + t.Fatal("expected preserved label metadata for a literal-dot owner name") + } + + // Reuse the same layer for a name that needs no preservation. + if err := d.DecodeFromBytes(plain, gopacket.NilDecodeFeedback); err != nil { + t.Fatal(err) + } + if d.Answers[0].names != nil { + t.Fatalf("stale name metadata leaked across decode reuse: %+v", d.Answers[0].names) + } + out := mustSerializeDNS(t, &d) + if !bytes.Contains(out, dnsWireName([]byte("plain"), []byte("local"))) { + t.Fatalf("reused decode did not serialize the plain name as split labels: %x", out) + } +} + +func TestDNSDecodePreservedLabelsDoNotAliasDecodeBuffer(t *testing.T) { + // Internal invariant: preserved label metadata must not alias the layer's + // decode buffer, so a decoded name survives buffer/layer reuse. The + // compression path reconstructs the pointed-to labels by splitting the + // resolved name; those slices must be copied, not left pointing into the + // decode buffer. + // + // This is a white-box check. It is not observable through serialization + // alone: the public Name field also aliases the decode buffer, and the + // encode path falls back to presentation parsing whenever Name differs from + // its decode snapshot, so the labels are never read once stale. The black-box + // input-reuse and compression tests therefore cannot catch this; the labels + // being independent metadata is what makes it matter. + header := []byte{ + 0x12, 0x34, + 0x84, 0x00, + 0x00, 0x01, // QDCOUNT + 0x00, 0x01, // ANCOUNT + 0x00, 0x00, + 0x00, 0x00, + } + // Question name "aaa.bbb.local" sits at offset 12 and is the pointer target. + msg := append(append([]byte{}, header...), dnsWireName([]byte("aaa"), []byte("bbb"), []byte("local"))...) + msg = append(msg, 0x00, 0x01, 0x00, 0x01) // QTYPE A, QCLASS IN + // Answer name: literal-dot label "foo.bar" followed by a pointer to offset 12. + msg = append(msg, 0x07) + msg = append(msg, []byte("foo.bar")...) + msg = append(msg, 0xc0, 0x0c) + msg = append(msg, + 0x00, 0x01, 0x00, 0x01, // A, IN + 0x00, 0x00, 0x00, 0x78, // TTL + 0x00, 0x04, // RDLENGTH + 192, 0, 2, 1, + ) + + var d DNS + if err := d.DecodeFromBytes(msg, gopacket.NilDecodeFeedback); err != nil { + t.Fatal(err) + } + if got := string(d.Answers[0].Name); got != "foo.bar.aaa.bbb.local" { + t.Fatalf("decoded name = %q, want foo.bar.aaa.bbb.local", got) + } + want := []string{"foo.bar", "aaa", "bbb", "local"} + labels := d.Answers[0].names.name.labels + if len(labels) != len(want) { + t.Fatalf("decoded %d labels, want %d: %q", len(labels), len(want), labels) + } + for i := range want { + if string(labels[i]) != want[i] { + t.Fatalf("label %d = %q, want %q", i, labels[i], want[i]) + } + } + // Corrupt the decode buffer in place; preserved labels must be unaffected. + for i := range d.buffer { + d.buffer[i] = '?' + } + for i := range want { + if string(labels[i]) != want[i] { + t.Fatalf("preserved label %d aliased the decode buffer: got %q after reuse, want %q", i, labels[i], want[i]) + } + } +} + func TestDNSEncodeQuery(t *testing.T) { dns := &DNS{ID: 1234, OpCode: DNSOpCodeQuery, RD: true} dns.Questions = append(dns.Questions,