Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions conntrack_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) {
// <BEuint64>
// <len, CTA_TIMEOUT>
// <BEuint64>
// <len, CTA_LABELS>
// <binary data>
// <len, NLA_F_NESTED|CTA_PROTOINFO>

// CTA_TUPLE_ORIG
Expand All @@ -392,6 +394,14 @@ func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) {
ctTimeout := nl.NewRtAttr(nl.CTA_TIMEOUT, nl.BEUint32Attr(s.TimeOut))

payload = append(payload, ctTupleOrig, ctTupleReply, ctMark, ctTimeout)
// Labels: nil => do not send; 16 zero bytes => clear all labels.
if s.Labels != nil {
if len(s.Labels) != 16 {
return nil, fmt.Errorf("conntrack CTA_LABELS must be 16 bytes, got %d", len(s.Labels))
}
ctLabels := nl.NewRtAttr(nl.CTA_LABELS, s.Labels)
payload = append(payload, ctLabels)
}

if s.ProtoInfo != nil {
switch p := s.ProtoInfo.(type) {
Expand Down
152 changes: 152 additions & 0 deletions conntrack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package netlink

import (
"bytes"
"encoding/binary"
"fmt"
"net"
Expand Down Expand Up @@ -72,10 +73,16 @@ func ensureCtHooksInThisNS(t *testing.T) func() {
var addedInput, addedOutput bool
if ipt(false, "-C", "INPUT", "-m", "conntrack", "--ctstate", "NEW,ESTABLISHED", "-j", "ACCEPT") != nil {
ipt(true, "-I", "INPUT", "-m", "conntrack", "--ctstate", "NEW,ESTABLISHED", "-j", "ACCEPT")
// Add a rule to set conntrack label to allocate the label space
// https://lore.kernel.org/netfilter-devel/[email protected]/
ipt(true, "-I", "INPUT", "-m", "connlabel", "--set", "--label", "1")
addedInput = true
}
if ipt(false, "-C", "OUTPUT", "-m", "conntrack", "--ctstate", "ESTABLISHED", "-j", "ACCEPT") != nil {
ipt(true, "-I", "OUTPUT", "-m", "conntrack", "--ctstate", "ESTABLISHED", "-j", "ACCEPT")
// Add a rule to set conntrack label to allocate the label space
// https://lore.kernel.org/netfilter-devel/[email protected]/
ipt(true, "-I", "OUTPUT", "-m", "connlabel", "--set", "--label", "1")
addedOutput = true
}
return func() {
Expand All @@ -98,6 +105,12 @@ func ensureCtHooksInThisNS(t *testing.T) func() {
_ = exec.Command("nft", "add", "chain", "inet", "ct_test", "output",
"{", "type", "filter", "hook", "output", "priority", "0", ";",
"ct", "state", "established", "accept", "}").Run()
// Add a rule to set conntrack label to allocate the label space
// https://lore.kernel.org/netfilter-devel/[email protected]/
_ = exec.Command("nft", "add", "rule", "inet", "ct_test", "output",
"ct", "label", "set", "1").Run()
_ = exec.Command("nft", "add", "rule", "inet", "ct_test", "input",
"ct", "label", "set", "1").Run()
return func() {
_ = exec.Command("nft", "delete", "table", "inet", "ct_test").Run()
}
Expand Down Expand Up @@ -1498,6 +1511,138 @@ func TestConntrackCreateV6(t *testing.T) {
checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo)
}

// TestConntrackLabels test the conntrack table labels
// Creates some flows and then checks the labels associated
func TestConntrackLabels(t *testing.T) {
skipUnlessRoot(t)
t.Cleanup(setUpNetlinkTestWithKModule(t, "nf_conntrack"))
t.Cleanup(setUpNetlinkTestWithKModule(t, "nf_conntrack_netlink"))
k, m, err := KernelVersion()
if err != nil {
t.Fatal(err)
}
// conntrack l3proto was unified since 4.19
// https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f
if k < 4 || k == 4 && m < 19 {
t.Cleanup(setUpNetlinkTestWithKModule(t, "nf_conntrack_ipv4"))
}
// Creates a new namespace and bring up the loopback interface
origns, ns, h := nsCreateAndEnter(t)
defer netns.Set(*origns)
defer origns.Close()
defer ns.Close()
defer runtime.UnlockOSThread()

flow := ConntrackFlow{
FamilyType: FAMILY_V4,
Forward: IPTuple{
SrcIP: net.IP{234, 234, 234, 234},
DstIP: net.IP{123, 123, 123, 123},
SrcPort: 48385,
DstPort: 53,
Protocol: unix.IPPROTO_TCP,
},
Reverse: IPTuple{
SrcIP: net.IP{123, 123, 123, 123},
DstIP: net.IP{234, 234, 234, 234},
SrcPort: 53,
DstPort: 48385,
Protocol: unix.IPPROTO_TCP,
},
// No point checking equivalence of timeout, but value must
// be reasonable to allow for a potentially slow subsequent read.
TimeOut: 100,
Mark: 12,
Labels: []byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0},
ProtoInfo: &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_SYN_SENT2,
},
}

err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V4, &flow)
if err == nil {
t.Fatalf("expected an error to occur when trying to update a non-existant conntrack: %+v", flow)
}

err = h.ConntrackCreate(ConntrackTable, nl.FAMILY_V4, &flow)
if err != nil {
t.Fatalf("failed to insert conntrack: %s", err)
}

flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4)
if err != nil {
t.Fatalf("failed to list conntracks following successful insert: %s", err)
}

filter := ConntrackFilter{
ipNetFilter: map[ConntrackFilterType]*net.IPNet{
ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP),
ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP),
ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP),
ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP),
},
portFilter: map[ConntrackFilterType]uint16{
ConntrackOrigSrcPort: flow.Forward.SrcPort,
ConntrackOrigDstPort: flow.Forward.DstPort,
},
protoFilter: unix.IPPROTO_TCP,
}

var match *ConntrackFlow
for _, f := range flows {
if filter.MatchConntrackFlow(f) {
match = f
break
}
}

if match == nil {
t.Fatalf("Didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter)
} else {
t.Logf("Found entry in conntrack table matching original flow: %+v labels=%+v", match, match.Labels)
}
checkFlowsEqual(t, &flow, match)
checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo)

// Change the conntrack and update the kernel entry.
flow.Mark = 10
flow.Labels = make([]byte, 16) // zero labels
flow.ProtoInfo = &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_ESTABLISHED,
}
err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V4, &flow)
if err != nil {
t.Fatalf("failed to update conntrack with new mark: %s", err)
}

// Look for updated conntrack.
flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4)
if err != nil {
t.Fatalf("failed to list conntracks following successful update: %s", err)
}

var updatedMatch *ConntrackFlow
for _, f := range flows {
if filter.MatchConntrackFlow(f) {
updatedMatch = f
break
}
}
if updatedMatch == nil {
t.Fatalf("Didn't find any matching conntrack entries for updated flow: %+v\n Filter used: %+v", flow, filter)
} else {
t.Logf("Found entry in conntrack table matching updated flow: %+v labels=%+v", updatedMatch, updatedMatch.Labels)
}

// To clear the labels we send an empty slice, but when reading back
// from the kernel we get a nil slice.
flow.Labels = nil
checkFlowsEqual(t, &flow, updatedMatch)
checkProtoInfosEqual(t, flow.ProtoInfo, updatedMatch.ProtoInfo)
// Switch back to the original namespace
netns.Set(*origns)
}

// TestConntrackFlowToNlData generates a serialized representation of a
// ConntrackFlow and runs the resulting bytes back through `parseRawData` to validate.
func TestConntrackFlowToNlData(t *testing.T) {
Expand All @@ -1518,6 +1663,7 @@ func TestConntrackFlowToNlData(t *testing.T) {
Protocol: unix.IPPROTO_TCP,
},
Mark: 5,
Labels: []byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0},
TimeOut: 10,
ProtoInfo: &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_ESTABLISHED,
Expand All @@ -1540,6 +1686,7 @@ func TestConntrackFlowToNlData(t *testing.T) {
Protocol: unix.IPPROTO_TCP,
},
Mark: 5,
Labels: []byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0},
TimeOut: 10,
ProtoInfo: &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_ESTABLISHED,
Expand Down Expand Up @@ -1596,6 +1743,11 @@ func checkFlowsEqual(t *testing.T, f1, f2 *ConntrackFlow) {
t.Logf("Reverse tuples mismatch. Tuple1 reverse flow: %+v, Tuple2 reverse flow: %+v.\n", f1.Reverse, f2.Reverse)
t.Fail()
}

if !bytes.Equal(f1.Labels, f2.Labels) {
t.Logf("Conntrack flow Labels differ. Tuple1: %+v, Tuple2: %+v.\n", f1.Labels, f2.Labels)
t.Fail()
}
}

func checkProtoInfosEqual(t *testing.T, p1, p2 ProtoInfo) {
Expand Down