diff --git a/conntrack_linux.go b/conntrack_linux.go index 771fa703..ff20869b 100644 --- a/conntrack_linux.go +++ b/conntrack_linux.go @@ -366,6 +366,8 @@ func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) { // // // + // + // // // CTA_TUPLE_ORIG @@ -392,6 +394,14 @@ func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) { ctTimeout := nl.NewRtAttr(nl.CTA_TIMEOUT, nl.BEUint32Attr(s.TimeOut)) payload = append(payload, ctTupleOrig, ctTupleReply, ctMark, ctTimeout) + // Labels: nil => do not send; 16 zero bytes => clear all labels. + if s.Labels != nil { + if len(s.Labels) != 16 { + return nil, fmt.Errorf("conntrack CTA_LABELS must be 16 bytes, got %d", len(s.Labels)) + } + ctLabels := nl.NewRtAttr(nl.CTA_LABELS, s.Labels) + payload = append(payload, ctLabels) + } if s.ProtoInfo != nil { switch p := s.ProtoInfo.(type) { diff --git a/conntrack_test.go b/conntrack_test.go index fd036b8d..48e5c4a1 100644 --- a/conntrack_test.go +++ b/conntrack_test.go @@ -4,6 +4,7 @@ package netlink import ( + "bytes" "encoding/binary" "fmt" "net" @@ -72,10 +73,16 @@ func ensureCtHooksInThisNS(t *testing.T) func() { var addedInput, addedOutput bool if ipt(false, "-C", "INPUT", "-m", "conntrack", "--ctstate", "NEW,ESTABLISHED", "-j", "ACCEPT") != nil { ipt(true, "-I", "INPUT", "-m", "conntrack", "--ctstate", "NEW,ESTABLISHED", "-j", "ACCEPT") + // Add a rule to set conntrack label to allocate the label space + // https://lore.kernel.org/netfilter-devel/aPdkVOTuUElaFKZZ@strlen.de/ + ipt(true, "-I", "INPUT", "-m", "connlabel", "--set", "--label", "1") addedInput = true } if ipt(false, "-C", "OUTPUT", "-m", "conntrack", "--ctstate", "ESTABLISHED", "-j", "ACCEPT") != nil { ipt(true, "-I", "OUTPUT", "-m", "conntrack", "--ctstate", "ESTABLISHED", "-j", "ACCEPT") + // Add a rule to set conntrack label to allocate the label space + // https://lore.kernel.org/netfilter-devel/aPdkVOTuUElaFKZZ@strlen.de/ + ipt(true, "-I", "OUTPUT", "-m", "connlabel", "--set", "--label", "1") addedOutput = true } return func() { @@ -98,6 +105,12 @@ func ensureCtHooksInThisNS(t *testing.T) func() { _ = exec.Command("nft", "add", "chain", "inet", "ct_test", "output", "{", "type", "filter", "hook", "output", "priority", "0", ";", "ct", "state", "established", "accept", "}").Run() + // Add a rule to set conntrack label to allocate the label space + // https://lore.kernel.org/netfilter-devel/aPdkVOTuUElaFKZZ@strlen.de/ + _ = exec.Command("nft", "add", "rule", "inet", "ct_test", "output", + "ct", "label", "set", "1").Run() + _ = exec.Command("nft", "add", "rule", "inet", "ct_test", "input", + "ct", "label", "set", "1").Run() return func() { _ = exec.Command("nft", "delete", "table", "inet", "ct_test").Run() } @@ -1498,6 +1511,138 @@ func TestConntrackCreateV6(t *testing.T) { checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo) } +// TestConntrackLabels test the conntrack table labels +// Creates some flows and then checks the labels associated +func TestConntrackLabels(t *testing.T) { + skipUnlessRoot(t) + t.Cleanup(setUpNetlinkTestWithKModule(t, "nf_conntrack")) + t.Cleanup(setUpNetlinkTestWithKModule(t, "nf_conntrack_netlink")) + k, m, err := KernelVersion() + if err != nil { + t.Fatal(err) + } + // conntrack l3proto was unified since 4.19 + // https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f + if k < 4 || k == 4 && m < 19 { + t.Cleanup(setUpNetlinkTestWithKModule(t, "nf_conntrack_ipv4")) + } + // Creates a new namespace and bring up the loopback interface + origns, ns, h := nsCreateAndEnter(t) + defer netns.Set(*origns) + defer origns.Close() + defer ns.Close() + defer runtime.UnlockOSThread() + + flow := ConntrackFlow{ + FamilyType: FAMILY_V4, + Forward: IPTuple{ + SrcIP: net.IP{234, 234, 234, 234}, + DstIP: net.IP{123, 123, 123, 123}, + SrcPort: 48385, + DstPort: 53, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.IP{123, 123, 123, 123}, + DstIP: net.IP{234, 234, 234, 234}, + SrcPort: 53, + DstPort: 48385, + Protocol: unix.IPPROTO_TCP, + }, + // No point checking equivalence of timeout, but value must + // be reasonable to allow for a potentially slow subsequent read. + TimeOut: 100, + Mark: 12, + Labels: []byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, + ProtoInfo: &ProtoInfoTCP{ + State: nl.TCP_CONNTRACK_SYN_SENT2, + }, + } + + err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V4, &flow) + if err == nil { + t.Fatalf("expected an error to occur when trying to update a non-existant conntrack: %+v", flow) + } + + err = h.ConntrackCreate(ConntrackTable, nl.FAMILY_V4, &flow) + if err != nil { + t.Fatalf("failed to insert conntrack: %s", err) + } + + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to list conntracks following successful insert: %s", err) + } + + filter := ConntrackFilter{ + ipNetFilter: map[ConntrackFilterType]*net.IPNet{ + ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP), + ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP), + ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP), + ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP), + }, + portFilter: map[ConntrackFilterType]uint16{ + ConntrackOrigSrcPort: flow.Forward.SrcPort, + ConntrackOrigDstPort: flow.Forward.DstPort, + }, + protoFilter: unix.IPPROTO_TCP, + } + + var match *ConntrackFlow + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + match = f + break + } + } + + if match == nil { + t.Fatalf("Didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter) + } else { + t.Logf("Found entry in conntrack table matching original flow: %+v labels=%+v", match, match.Labels) + } + checkFlowsEqual(t, &flow, match) + checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo) + + // Change the conntrack and update the kernel entry. + flow.Mark = 10 + flow.Labels = make([]byte, 16) // zero labels + flow.ProtoInfo = &ProtoInfoTCP{ + State: nl.TCP_CONNTRACK_ESTABLISHED, + } + err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V4, &flow) + if err != nil { + t.Fatalf("failed to update conntrack with new mark: %s", err) + } + + // Look for updated conntrack. + flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to list conntracks following successful update: %s", err) + } + + var updatedMatch *ConntrackFlow + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + updatedMatch = f + break + } + } + if updatedMatch == nil { + t.Fatalf("Didn't find any matching conntrack entries for updated flow: %+v\n Filter used: %+v", flow, filter) + } else { + t.Logf("Found entry in conntrack table matching updated flow: %+v labels=%+v", updatedMatch, updatedMatch.Labels) + } + + // To clear the labels we send an empty slice, but when reading back + // from the kernel we get a nil slice. + flow.Labels = nil + checkFlowsEqual(t, &flow, updatedMatch) + checkProtoInfosEqual(t, flow.ProtoInfo, updatedMatch.ProtoInfo) + // Switch back to the original namespace + netns.Set(*origns) +} + // TestConntrackFlowToNlData generates a serialized representation of a // ConntrackFlow and runs the resulting bytes back through `parseRawData` to validate. func TestConntrackFlowToNlData(t *testing.T) { @@ -1518,6 +1663,7 @@ func TestConntrackFlowToNlData(t *testing.T) { Protocol: unix.IPPROTO_TCP, }, Mark: 5, + Labels: []byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, TimeOut: 10, ProtoInfo: &ProtoInfoTCP{ State: nl.TCP_CONNTRACK_ESTABLISHED, @@ -1540,6 +1686,7 @@ func TestConntrackFlowToNlData(t *testing.T) { Protocol: unix.IPPROTO_TCP, }, Mark: 5, + Labels: []byte{0, 0, 0, 0, 3, 4, 61, 141, 207, 170, 2, 0, 0, 0, 0, 0}, TimeOut: 10, ProtoInfo: &ProtoInfoTCP{ State: nl.TCP_CONNTRACK_ESTABLISHED, @@ -1596,6 +1743,11 @@ func checkFlowsEqual(t *testing.T, f1, f2 *ConntrackFlow) { t.Logf("Reverse tuples mismatch. Tuple1 reverse flow: %+v, Tuple2 reverse flow: %+v.\n", f1.Reverse, f2.Reverse) t.Fail() } + + if !bytes.Equal(f1.Labels, f2.Labels) { + t.Logf("Conntrack flow Labels differ. Tuple1: %+v, Tuple2: %+v.\n", f1.Labels, f2.Labels) + t.Fail() + } } func checkProtoInfosEqual(t *testing.T, p1, p2 ProtoInfo) {