From df7a6b658d4cecc2230757962feebe8b0609ae17 Mon Sep 17 00:00:00 2001 From: Alexander Cueva Date: Wed, 9 Jul 2025 14:38:05 -0700 Subject: [PATCH] Update netfilter protocol address family checks. Netfilter messages return different error values when validating address families depending on the message type being received. Updated functionality to reflect that. PiperOrigin-RevId: 781215895 --- pkg/abi/linux/netlink.go | 32 +- pkg/abi/linux/nf_tables.go | 17 +- pkg/sentry/socket/netlink/netfilter/BUILD | 1 + .../socket/netlink/netfilter/protocol.go | 163 +++- pkg/sentry/socket/netlink/nlmsg/BUILD | 6 +- pkg/sentry/socket/netlink/nlmsg/message.go | 15 + pkg/tcpip/nftables/BUILD | 1 + pkg/tcpip/nftables/nftables.go | 202 +++- pkg/tcpip/nftables/nftables_types.go | 58 +- pkg/tcpip/stack/nftables_types.go | 9 +- test/syscalls/linux/BUILD | 4 + .../linux/socket_netlink_netfilter.cc | 898 ++++++++++++------ .../linux/socket_netlink_netfilter_util.cc | 159 ++++ .../linux/socket_netlink_netfilter_util.h | 76 ++ test/syscalls/linux/socket_netlink_util.cc | 4 +- 15 files changed, 1315 insertions(+), 330 deletions(-) diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go index 232fee67e6..aad50319e2 100644 --- a/pkg/abi/linux/netlink.go +++ b/pkg/abi/linux/netlink.go @@ -66,21 +66,35 @@ type NetlinkMessageHeader struct { // NetlinkMessageHeaderSize is the size of NetlinkMessageHeader. const NetlinkMessageHeaderSize = 16 -// Netlink message header flags, from uapi/linux/netlink.h. +// Netlink message header flag values, from uapi/linux/netlink.h. const ( NLM_F_REQUEST = 0x1 NLM_F_MULTI = 0x2 NLM_F_ACK = 0x4 NLM_F_ECHO = 0x8 NLM_F_DUMP_INTR = 0x10 - NLM_F_ROOT = 0x100 - NLM_F_MATCH = 0x200 - NLM_F_ATOMIC = 0x400 - NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH - NLM_F_REPLACE = 0x100 - NLM_F_EXCL = 0x200 - NLM_F_CREATE = 0x400 - NLM_F_APPEND = 0x800 +) + +// Netlink message header flags for GET requests, from uapi/linux/netlink.h. 
+const ( + NLM_F_ROOT = 0x100 + NLM_F_MATCH = 0x200 + NLM_F_ATOMIC = 0x400 + NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH +) + +// Netlink message header flags for NEW requests, from uapi/linux/netlink.h. +const ( + NLM_F_REPLACE = 0x100 + NLM_F_EXCL = 0x200 + NLM_F_CREATE = 0x400 + NLM_F_APPEND = 0x800 +) + +// Netlink message header flags for DELETE requests, from uapi/linux/netlink.h. +const ( + NLM_F_NONREC = 0x100 + NLM_F_BULK = 0x200 ) // Standard netlink message types, from uapi/linux/netlink.h. diff --git a/pkg/abi/linux/nf_tables.go b/pkg/abi/linux/nf_tables.go index a18d5fd549..eee1e6f1c0 100644 --- a/pkg/abi/linux/nf_tables.go +++ b/pkg/abi/linux/nf_tables.go @@ -16,6 +16,18 @@ package linux // This file contains constants required to support nf_tables. +// Name length constants for nf_table structures. These correspond to values in +// include/uapi/linux/netfilter/nf_tables.h. +const ( + NFT_NAME_MAXLEN = 256 + NFT_TABLE_MAXNAMELEN = NFT_NAME_MAXLEN + NFT_CHAIN_MAXNAMELEN = NFT_NAME_MAXLEN + NFT_SET_MAXNAMELEN = NFT_NAME_MAXLEN + NFT_OBJ_MAXNAMELEN = NFT_NAME_MAXLEN + NFT_USERDATA_MAXLEN = 256 + NFT_OSF_MAXGENRELEN = 16 +) + // 16-byte Registers that can be used to maintain state for rules. // These correspond to values in include/uapi/linux/netfilter/nf_tables.h. const ( @@ -127,7 +139,10 @@ const ( // NfTableFlags represents table flags that can be set for a table, namely dormant. // These correspond to values in include/uapi/linux/netfilter/nf_tables.h. const ( - NFT_TABLE_F_DORMANT = 0x1 + NFT_TABLE_F_DORMANT uint32 = 0x1 + NFT_TABLE_F_OWNER = 0x2 + NFT_TABLE_F_PERSIST = 0x4 + NFT_TABLE_F_MASK = NFT_TABLE_F_DORMANT | NFT_TABLE_F_OWNER | NFT_TABLE_F_PERSIST ) // NfTableAttributes represents the netfilter table attributes. 
diff --git a/pkg/sentry/socket/netlink/netfilter/BUILD b/pkg/sentry/socket/netlink/netfilter/BUILD index 320496c7ac..28b355cff5 100644 --- a/pkg/sentry/socket/netlink/netfilter/BUILD +++ b/pkg/sentry/socket/netlink/netfilter/BUILD @@ -13,6 +13,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/log", + "//pkg/marshal/primitive", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/socket/netlink", diff --git a/pkg/sentry/socket/netlink/netfilter/protocol.go b/pkg/sentry/socket/netlink/netfilter/protocol.go index 86404cee62..ce8b6c3c45 100644 --- a/pkg/sentry/socket/netlink/netfilter/protocol.go +++ b/pkg/sentry/socket/netlink/netfilter/protocol.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" @@ -88,12 +89,11 @@ func (p *Protocol) ProcessMessage(ctx context.Context, s *netlink.Socket, msg *n return syserr.ErrInvalidArgument } - // Nftables functions error check the address family value. - family := stack.AddressFamily(nfGenMsg.Family) + family := nftables.AFtoNetlinkAF(nfGenMsg.Family) // TODO: b/421437663 - Match the message type and call the appropriate Nftables function. 
switch msgType { case linux.NFT_MSG_NEWTABLE: - if err := p.newTable(nft, attrs, family, hdr.Flags); err != nil { + if err := p.newTable(nft, attrs, family, hdr.Flags, ms); err != nil { log.Debugf("Nftables new table error: %s", err) return err.GetError() } @@ -104,6 +104,12 @@ func (p *Protocol) ProcessMessage(ctx context.Context, s *netlink.Socket, msg *n return err.GetError() } return nil + case linux.NFT_MSG_DELTABLE, linux.NFT_MSG_DESTROYTABLE: + if err := p.deleteTable(nft, attrs, family, hdr, msgType, ms); err != nil { + log.Debugf("Nftables delete table error: %s", err) + return err.GetError() + } + return nil default: log.Debugf("Unsupported message type: %d", msgType) return syserr.ErrNotSupported @@ -111,7 +117,11 @@ func (p *Protocol) ProcessMessage(ctx context.Context, s *netlink.Socket, msg *n } // newTable creates a new table for the given family. -func (p *Protocol) newTable(nft *nftables.NFTables, attrs map[uint16]nlmsg.BytesView, family stack.AddressFamily, flags uint16) *syserr.AnnotatedError { +func (p *Protocol) newTable(nft *nftables.NFTables, attrs map[uint16]nlmsg.BytesView, family stack.AddressFamily, flags uint16, ms *nlmsg.MessageSet) *syserr.AnnotatedError { + if family == stack.NumAFs { + return syserr.NewAnnotatedError(syserr.ErrNotSupported, fmt.Sprintf("Nftables: Address family is not supported")) + } + // TODO: b/421437663 - Handle the case where the table name is set to empty string. // The table name is required. 
tabNameBytes, ok := attrs[linux.NFTA_TABLE_NAME] @@ -119,13 +129,7 @@ func (p *Protocol) newTable(nft *nftables.NFTables, attrs map[uint16]nlmsg.Bytes return syserr.NewAnnotatedError(syserr.ErrInvalidArgument, fmt.Sprintf("Nftables: Table name attribute is malformed or not found")) } - var dormant bool - if dbytes, ok := attrs[linux.NFTA_TABLE_FLAGS]; ok { - dflag, _ := dbytes.Uint32() - dormant = (dflag & linux.NFT_TABLE_F_DORMANT) == linux.NFT_TABLE_F_DORMANT - } - - tab, err := nft.GetTable(family, tabNameBytes.String()) + tab, err := nft.GetTable(family, tabNameBytes.String(), uint32(ms.PortID)) if err != nil && err.GetError() != syserr.ErrNoFileOrDir { return err } @@ -133,20 +137,77 @@ func (p *Protocol) newTable(nft *nftables.NFTables, attrs map[uint16]nlmsg.Bytes // If a table already exists, only update its dormant flags if NLM_F_EXCL and NLM_F_REPLACE // are not set. From net/netfilter/nf_tables_api.c:nf_tables_newtable:nf_tables_updtable if tab != nil { - if flags&linux.NLM_F_EXCL == linux.NLM_F_EXCL { - return syserr.NewAnnotatedError(syserr.ErrExists, fmt.Sprintf("Nftables: Table with name: %s already exists", tabNameBytes.String())) + if flags&linux.NLM_F_EXCL != 0 { + return syserr.NewAnnotatedError(syserr.ErrExists, fmt.Sprintf("Nftables: Table with name: %s already exists", tab.GetName())) } - if flags&linux.NLM_F_REPLACE == linux.NLM_F_REPLACE { - return syserr.NewAnnotatedError(syserr.ErrNotSupported, fmt.Sprintf("Nftables: Table with name: %s already exists and NLM_F_REPLACE is not supported", tabNameBytes.String())) + if flags&linux.NLM_F_REPLACE != 0 { + return syserr.NewAnnotatedError(syserr.ErrNotSupported, fmt.Sprintf("Nftables: Table with name: %s already exists and NLM_F_REPLACE is not supported", tab.GetName())) } - } else { - tab, err = nft.CreateTable(family, tabNameBytes.String()) - if err != nil { + + return p.updateTable(nft, tab, attrs, family, ms) + } + + // TODO: b/421437663 - Support additional user-specified table flags. 
+ var attrFlags uint32 = 0 + if uflags, ok := attrs[linux.NFTA_TABLE_FLAGS]; ok { + attrFlags, _ = uflags.Uint32() + // Flags sent through the NFTA_TABLE_FLAGS attribute are of type uint32 + // but should only have user flags set. This check needs to be done before table creation. + if attrFlags & ^uint32(linux.NFT_TABLE_F_MASK) != 0 { + return syserr.NewAnnotatedError(syserr.ErrNotSupported, fmt.Sprintf("Nftables: Table flags set are not supported")) + } + } + + tab, err = nft.CreateTable(family, tabNameBytes.String()) + if err != nil { + return err + } + + if udata, ok := attrs[linux.NFTA_TABLE_USERDATA]; ok { + tab.SetUserData(udata) + } + + // Flags should only be assigned after we have successfully created the table. + dormant := (attrFlags & uint32(linux.NFT_TABLE_F_DORMANT)) != 0 + tab.SetDormant(dormant) + + owner := (attrFlags & uint32(linux.NFT_TABLE_F_OWNER)) != 0 + if owner { + if err := tab.SetOwner(uint32(ms.PortID)); err != nil { + return err + } + } + + return nil +} + +// updateTable updates an existing table. +func (p *Protocol) updateTable(nft *nftables.NFTables, tab *nftables.Table, attrs map[uint16]nlmsg.BytesView, family stack.AddressFamily, ms *nlmsg.MessageSet) *syserr.AnnotatedError { + var attrFlags uint32 + if uflags, ok := attrs[linux.NFTA_TABLE_FLAGS]; ok { + attrFlags, _ = uflags.Uint32() + // This check needs to be done before table update. + if attrFlags & ^uint32(linux.NFT_TABLE_F_MASK) > 0 { + return syserr.NewAnnotatedError(syserr.ErrNotSupported, fmt.Sprintf("Nftables: Table flags set are not supported")) + } + } + + // When updating the table, if the table has an owner but the owner flag isn't set, + // the table should not be updated. + // From net/netfilter/nf_tables_api.c:nf_tables_updtable. 
+ if tab.HasOwner() && (attrFlags&uint32(linux.NFT_TABLE_F_OWNER)) == 0 { + return syserr.NewAnnotatedError(syserr.ErrNotSupported, fmt.Sprintf("Nftables: Table with name: %s already has an owner but NFT_TABLE_F_OWNER was not set when updating the table", tab.GetName())) + } + + // The owner is only updated if the table has no previous owner. + if !tab.HasOwner() && attrFlags&uint32(linux.NFT_TABLE_F_OWNER) != 0 { + if err := tab.SetOwner(uint32(ms.PortID)); err != nil { return err } } + dormant := (attrFlags & uint32(linux.NFT_TABLE_F_DORMANT)) != 0 tab.SetDormant(dormant) return nil } @@ -159,12 +220,16 @@ func (p *Protocol) getTable(nft *nftables.NFTables, attrs map[uint16]nlmsg.Bytes return syserr.NewAnnotatedError(syserr.ErrInvalidArgument, fmt.Sprintf("Nftables: Table name attribute is malformed or not found")) } - tab, err := nft.GetTable(family, tabNameBytes.String()) + tab, err := nft.GetTable(family, tabNameBytes.String(), uint32(ms.PortID)) if err != nil { return err } tabName := tab.GetName() + userFlags, err := tab.GetLinuxUserFlagSet() + if err != nil { + return err + } m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: uint16(linux.NFNL_SUBSYS_NFTABLES)<<8 | uint16(linux.NFT_MSG_GETTABLE), }) @@ -176,13 +241,73 @@ func (p *Protocol) getTable(nft *nftables.NFTables, attrs map[uint16]nlmsg.Bytes ResourceID: uint16(0), }) m.PutAttrString(linux.NFTA_TABLE_NAME, tabName) + m.PutAttr(linux.NFTA_TABLE_USE, primitive.AllocateUint32(uint32(tab.ChainCount()))) + m.PutAttr(linux.NFTA_TABLE_HANDLE, primitive.AllocateUint64(tab.GetHandle())) + m.PutAttr(linux.NFTA_TABLE_FLAGS, primitive.AllocateUint8(userFlags)) + + if tab.HasOwner() { + m.PutAttr(linux.NFTA_TABLE_OWNER, primitive.AllocateUint32(tab.GetOwner())) + } + + if tab.HasUserData() { + m.PutAttr(linux.NFTA_TABLE_USERDATA, primitive.AsByteSlice(tab.GetUserData())) + } + return nil } +// deleteTable deletes a table for the given family. 
+func (p *Protocol) deleteTable(nft *nftables.NFTables, attrs map[uint16]nlmsg.BytesView, family stack.AddressFamily, hdr linux.NetlinkMessageHeader, msgType linux.NfTableMsgType, ms *nlmsg.MessageSet) *syserr.AnnotatedError {
+	if family == stack.Unspec || (!hasAttr(linux.NFTA_TABLE_NAME, attrs) && !hasAttr(linux.NFTA_TABLE_HANDLE, attrs)) {
+		nft.Flush(attrs, family, uint32(ms.PortID))
+		return nil
+	}
+	// A single table is targeted from here on; the handle attribute, when present, wins over the name.
+	var tab *nftables.Table
+	var err *syserr.AnnotatedError
+	if tabHandleBytes, ok := attrs[linux.NFTA_TABLE_HANDLE]; ok {
+		tabHandle, ok := tabHandleBytes.Uint64()
+		if !ok {
+			return syserr.NewAnnotatedError(syserr.ErrInvalidArgument, fmt.Sprintf("Nftables: Table handle attribute is malformed or not found"))
+		}
+
+		tab, err = nft.GetTableByHandle(family, tabHandle, uint32(ms.PortID))
+	} else {
+		tabNameBytes, ok := attrs[linux.NFTA_TABLE_NAME]
+		if !ok {
+			return syserr.NewAnnotatedError(syserr.ErrInvalidArgument, fmt.Sprintf("Nftables: Table name attribute is malformed or not found"))
+		}
+		tab, err = nft.GetTable(family, tabNameBytes.String(), uint32(ms.PortID))
+	}
+
+	if err != nil {
+		// NFT_MSG_DESTROYTABLE is idempotent: a missing table (ENOENT) is not an error for it.
+		if err.GetError() == syserr.ErrNoFileOrDir && msgType == linux.NFT_MSG_DESTROYTABLE {
+			return nil
+		}
+		return err
+	}
+
+	// Don't delete the table if it is not empty and NLM_F_NONREC is set.
+	if hdr.Flags&linux.NLM_F_NONREC != 0 && tab.ChainCount() > 0 {
+		return syserr.NewAnnotatedError(syserr.ErrBusy, fmt.Sprintf("Nftables: Table with family: %d and name: %s is not empty and NLM_F_NONREC is set", int(family), tab.GetName()))
+	}
+
+	_, err = nft.DeleteTable(family, tab.GetName())
+	return err
+}
+
+// netLinkMessagePayloadSize returns the size of the netlink message payload.
 func netLinkMessagePayloadSize(h *linux.NetlinkMessageHeader) int {
 	return int(h.Length) - linux.NetlinkMessageHeaderSize
 }
 
+// hasAttr returns whether the given attribute key is present in the attribute map.
+func hasAttr(attrName uint16, attrs map[uint16]nlmsg.BytesView) bool { + _, ok := attrs[attrName] + return ok +} + // init registers the NETLINK_NETFILTER provider. func init() { netlink.RegisterProvider(linux.NETLINK_NETFILTER, NewProtocol) diff --git a/pkg/sentry/socket/netlink/nlmsg/BUILD b/pkg/sentry/socket/netlink/nlmsg/BUILD index e7b82a85db..15700f7c41 100644 --- a/pkg/sentry/socket/netlink/nlmsg/BUILD +++ b/pkg/sentry/socket/netlink/nlmsg/BUILD @@ -10,11 +10,15 @@ go_library( srcs = [ "message.go", ], - visibility = ["//pkg/sentry:internal"], + visibility = [ + "//pkg/sentry:internal", + "//pkg/tcpip/nftables:__subpackages__", + ], deps = [ "//pkg/abi/linux", "//pkg/bits", "//pkg/hostarch", + "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", ], diff --git a/pkg/sentry/socket/netlink/nlmsg/message.go b/pkg/sentry/socket/netlink/nlmsg/message.go index 60375e3fef..50b48e8400 100644 --- a/pkg/sentry/socket/netlink/nlmsg/message.go +++ b/pkg/sentry/socket/netlink/nlmsg/message.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" ) @@ -252,17 +253,20 @@ func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize) if !ok { + log.Debugf("Failed to parse netlink attributes at header stage") return } hdr.UnmarshalUnsafe(hdrBytes) value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize) if !ok { + log.Debugf("Failed to parse %d bytes after %d header bytes", int(hdr.Length)-linux.NetlinkAttrHeaderSize, linux.NetlinkAttrHeaderSize) return } _, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO)) if !ok { + log.Debugf("Failed to parse netlink attributes at aligning stage") return } @@ -323,6 +327,17 @@ func (v *BytesView) Uint32() (uint32, bool) { return uint32(val), true } +// Uint64 converts 
the raw attribute value to uint64. +func (v *BytesView) Uint64() (uint64, bool) { + attr := []byte(*v) + val := primitive.Uint64(0) + if len(attr) != val.SizeBytes() { + return 0, false + } + val.UnmarshalBytes(attr) + return uint64(val), true +} + // Int32 converts the raw attribute value to int32. func (v *BytesView) Int32() (int32, bool) { attr := []byte(*v) diff --git a/pkg/tcpip/nftables/BUILD b/pkg/tcpip/nftables/BUILD index f569dc3ebc..7b2a00699d 100644 --- a/pkg/tcpip/nftables/BUILD +++ b/pkg/tcpip/nftables/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/abi/linux", "//pkg/atomicbitops", "//pkg/rand", + "//pkg/sentry/socket/netlink/nlmsg", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/checksum", diff --git a/pkg/tcpip/nftables/nftables.go b/pkg/tcpip/nftables/nftables.go index 50463efe6e..80991c99f5 100644 --- a/pkg/tcpip/nftables/nftables.go +++ b/pkg/tcpip/nftables/nftables.go @@ -19,7 +19,9 @@ import ( "slices" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sentry/socket/netlink/nlmsg" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -241,36 +243,62 @@ func NewNFTables(clock tcpip.Clock, rng rand.RNG) *NFTables { if rng.Reader == nil { panic("nftables state must be initialized with a non-nil random number generator") } - return &NFTables{clock: clock, startTime: clock.Now(), rng: rng} + return &NFTables{clock: clock, startTime: clock.Now(), rng: rng, tableHandleCounter: atomicbitops.Uint64{}} } -// Flush clears entire ruleset and all data for all address families. -func (nf *NFTables) Flush() { - for family := range stack.NumAFs { - nf.filters[family] = nil +// Flush clears the entire ruleset and all data for the given address family, or for all address +// families if the family is unspecified. Tables that are not owned by the given owner are not +// deleted. 
+func (nf *NFTables) Flush(attrs map[uint16]nlmsg.BytesView, family stack.AddressFamily, owner uint32) {
+	var tabName *string
+	if nameBytes, ok := attrs[linux.NFTA_TABLE_NAME]; ok {
+		name := nameBytes.String()
+		tabName = &name
+	}
+
+	if family != stack.Unspec {
+		nf.FlushAddressFamily(family, tabName, owner)
+		return
+	}
+
+	for stackFamily := range stack.NumAFs {
+		nf.FlushAddressFamily(stackFamily, tabName, owner)
 	}
 }
 
-// FlushAddressFamily clears ruleset and all data for the given address family,
-// returning an error if the address family is invalid.
-func (nf *NFTables) FlushAddressFamily(family stack.AddressFamily) *syserr.AnnotatedError {
-	// Ensures address family is valid.
-	if err := validateAddressFamily(family); err != nil {
-		return err
+// FlushAddressFamily deletes the tables of the given address family, skipping tables owned by a
+// different owner and, when tabName is non-nil, tables with another name. A family outside
+// nf.filters (such as stack.NumAFs, used for unsupported families) or without tables is a no-op.
+func (nf *NFTables) FlushAddressFamily(family stack.AddressFamily, tabName *string, owner uint32) {
+	if family < 0 || family >= stack.NumAFs || nf.filters[family] == nil {
+		return
 	}
 
-	nf.filters[family] = nil
-	return nil
-}
+	afFilter := nf.filters[family]
+	var tablesToDelete []TableInfo
+	for name, table := range afFilter.tables {
+		// Caller cannot delete a table they do not own.
+		if table.HasOwner() && table.GetOwner() != owner {
+			continue
+		}
 
-// GetTable validates the inputs and gets a table if it exists, error otherwise.
-func (nf *NFTables) GetTable(family stack.AddressFamily, tableName string) (*Table, *syserr.AnnotatedError) {
-	// Ensures address family is valid.
-	if err := validateAddressFamily(family); err != nil {
-		return nil, err
+		if tabName != nil && *tabName != table.GetName() {
+			continue
+		}
+		tablesToDelete = append(tablesToDelete, TableInfo{Name: name, Handle: table.GetHandle()})
 	}
 
+	for _, tableData := range tablesToDelete {
+		delete(afFilter.tables, tableData.Name)
+		delete(afFilter.tableHandles, tableData.Handle)
+	}
+}
+
+// GetTable validates the inputs and gets a table if it exists, error otherwise.
+func (nf *NFTables) GetTable(family stack.AddressFamily, tableName string, portID uint32) (*Table, *syserr.AnnotatedError) {
 	// Checks if the table map for the address family has been initialized.
+	// Unsupported families arrive as stack.NumAFs, which is outside the bounds of nf.filters, so
+	// the range is checked before indexing and such families are reported as having no tables.
-	if nf.filters[family] == nil || nf.filters[family].tables == nil {
+	if family < 0 || family >= stack.NumAFs || nf.filters[family] == nil || nf.filters[family].tables == nil {
 		return nil, syserr.NewAnnotatedError(syserr.ErrNoFileOrDir, fmt.Sprintf("table map for address family %v has no tables", family))
 	}
@@ -284,6 +312,40 @@ func (nf *NFTables) GetTable(family stack.AddressFamily, tableName string) (*Tab
 		return nil, syserr.NewAnnotatedError(syserr.ErrNoFileOrDir, fmt.Sprintf("table %s not found for address family %v", tableName, family))
 	}
 
+	// If the table has an owner, it must match the Netlink portID of the calling process.
+	// User space processes only have non-zero port ids.
+	// Only the kernel can have a zero port id.
+	if t.HasOwner() && portID != 0 && portID != t.GetOwner() {
+		return nil, syserr.NewAnnotatedError(syserr.ErrNotPermitted, fmt.Sprintf("table %s has owner %d, which does not match the Netlink portID of the calling process %d", tableName, t.GetOwner(), portID))
+	}
+
+	return t, nil
+}
+
+// GetTableByHandle validates the inputs and gets a table by its handle and family if it exists,
+// error otherwise.
+func (nf *NFTables) GetTableByHandle(family stack.AddressFamily, handle uint64, portID uint32) (*Table, *syserr.AnnotatedError) {
+	// Checks that the family is within nf.filters and that its table handle map has been initialized.
+	if family < 0 || family >= stack.NumAFs || nf.filters[family] == nil || nf.filters[family].tableHandles == nil {
+		return nil, syserr.NewAnnotatedError(syserr.ErrNoFileOrDir, fmt.Sprintf("table handle map for address family %v has no tables", family))
+	}
+
+	// Gets the corresponding table handle map for the address family.
+	tableHandleMap := nf.filters[family].tableHandles
+
+	// Checks if a table with the handle exists.
+	t, exists := tableHandleMap[handle]
+	if !exists {
+		return nil, syserr.NewAnnotatedError(syserr.ErrNoFileOrDir, fmt.Sprintf("table with handle %d not found for address family %v", handle, family))
+	}
+
+	// If the table has an owner, it must match the Netlink portID of the calling process.
+	// User space processes only have non-zero port ids.
+	// Only the kernel can have a zero port id.
+	if t.HasOwner() && portID != 0 && portID != t.GetOwner() {
+		return nil, syserr.NewAnnotatedError(syserr.ErrNotPermitted, fmt.Sprintf("table with handle %d has owner %d, which does not match the Netlink portID of the calling process %d", handle, t.GetOwner(), portID))
+	}
+
 	return t, nil
 }
 
@@ -304,15 +366,17 @@ func (nf *NFTables) AddTable(family stack.AddressFamily, name string,
 	// Initializes filter if first table for the address family.
 	if nf.filters[family] == nil {
 		nf.filters[family] = &addressFamilyFilter{
-			family:   family,
-			nftState: nf,
-			tables:   make(map[string]*Table),
-			hfStacks: make(map[stack.NFHook]*hookFunctionStack),
+			family:       family,
+			nftState:     nf,
+			tables:       make(map[string]*Table),
+			tableHandles: make(map[uint64]*Table),
+			hfStacks:     make(map[stack.NFHook]*hookFunctionStack),
 		}
 	}
 
 	// Gets the corresponding table map for the address family.
 	tableMap := nf.filters[family].tables
+	tableHandleMap := nf.filters[family].tableHandles
 
 	// Checks if a table with the same name already exists. If so, returns the
 	// existing table (unless errorOnDuplicate is true).
@@ -329,12 +393,19 @@ func (nf *NFTables) AddTable(family stack.AddressFamily, name string,
 		afFilter: nf.filters[family],
 		chains:   make(map[string]*Chain),
 		flagSet:  make(map[TableFlag]struct{}),
+		handle:   nf.getNewTableHandle(),
 	}
 	tableMap[name] = t
+	tableHandleMap[t.handle] = t
 
 	return t, nil
 }
 
+// getNewTableHandle returns a new table handle for the NFTables object.
+func (nf *NFTables) getNewTableHandle() uint64 { + return nf.tableHandleCounter.Add(1) +} + // CreateTable makes a new table for the specified address family like AddTable // but also returns an error if a table by the same name already exists. // Note: this interface mirrors the difference between the create and add @@ -353,7 +424,7 @@ func (nf *NFTables) DeleteTable(family stack.AddressFamily, tableName string) (b } // Gets and checks the table. - t, err := nf.GetTable(family, tableName) + t, err := nf.GetTable(family, tableName, 0) if err != nil { return false, err } @@ -363,15 +434,16 @@ func (nf *NFTables) DeleteTable(family stack.AddressFamily, tableName string) (b t.DeleteChain(chainName) } - // Deletes the table from the table map. + // Deletes the table from the table map and from the table handle map. delete(nf.filters[family].tables, tableName) + delete(nf.filters[family].tableHandles, t.handle) return true, nil } // GetChain validates the inputs and gets a chain if it exists, error otherwise. func (nf *NFTables) GetChain(family stack.AddressFamily, tableName string, chainName string) (*Chain, *syserr.AnnotatedError) { // Gets and checks the table. - t, err := nf.GetTable(family, tableName) + t, err := nf.GetTable(family, tableName, 0) if err != nil { return nil, err } @@ -389,7 +461,7 @@ func (nf *NFTables) GetChain(family stack.AddressFamily, tableName string, chain // Note: if the chain is not a base chain, info should be nil. func (nf *NFTables) AddChain(family stack.AddressFamily, tableName string, chainName string, info *BaseChainInfo, comment string, errorOnDuplicate bool) (*Chain, *syserr.AnnotatedError) { // Gets and checks the table. - t, err := nf.GetTable(family, tableName) + t, err := nf.GetTable(family, tableName, 0) if err != nil { return nil, err } @@ -411,7 +483,7 @@ func (nf *NFTables) CreateChain(family stack.AddressFamily, tableName string, ch // an error if the address family is invalid or the table doesn't exist. 
func (nf *NFTables) DeleteChain(family stack.AddressFamily, tableName string, chainName string) (bool, *syserr.AnnotatedError) { // Gets and checks the table. - t, err := nf.GetTable(family, tableName) + t, err := nf.GetTable(family, tableName, 0) if err != nil { return false, err } @@ -438,6 +510,56 @@ func (t *Table) GetAddressFamily() stack.AddressFamily { return t.afFilter.family } +// GetHandle returns the handle of the table. +func (t *Table) GetHandle() uint64 { + return t.handle +} + +// GetOwner returns the owner of the table. +func (t *Table) GetOwner() uint32 { + return t.owner +} + +// SetOwner sets the owner of the table. If the table already has an owner, it +// is not updated. +func (t *Table) SetOwner(nlpid uint32) *syserr.AnnotatedError { + // This should only be called once, when setting the owner of a table for the first time. + if t.HasOwner() { + return syserr.NewAnnotatedError(syserr.ErrNotSupported, fmt.Sprintf("table %s already has an owner", t.name)) + } + + t.flagSet[TableFlagOwner] = struct{}{} + t.owner = nlpid + return nil +} + +// HasOwner returns whether the table has an owner. +func (t *Table) HasOwner() bool { + _, ok := t.flagSet[TableFlagOwner] + return ok +} + +// GetUserData returns the user data of the table. +func (t *Table) GetUserData() []byte { + return t.userData +} + +// HasUserData returns whether the table has user data. +func (t *Table) HasUserData() bool { + return t.userData != nil +} + +// SetUserData sets the user data of the table. +func (t *Table) SetUserData(data []byte) { + // User data should only be set once. + if t.userData != nil { + return + } + + t.userData = make([]byte, len(data)) + copy(t.userData, data) +} + // IsDormant returns whether the table is dormant. func (t *Table) IsDormant() bool { _, dormant := t.flagSet[TableFlagDormant] @@ -453,6 +575,34 @@ func (t *Table) SetDormant(dormant bool) { } } +// GetLinuxFlagSet returns the flag set of the table. 
+// Although user flags map to uint8 space, internal flags could eventually be +// supported, which together map to a uint32 space. +func (t *Table) GetLinuxFlagSet() (uint32, *syserr.AnnotatedError) { + var flags uint32 = 0 + for flag := range t.flagSet { + switch flag { + case TableFlagDormant: + flags |= linux.NFT_TABLE_F_DORMANT + case TableFlagOwner: + flags |= linux.NFT_TABLE_F_OWNER + default: + return 0, syserr.NewAnnotatedError(syserr.ErrNotSupported, fmt.Sprintf("unsupported flag %v", flag)) + } + } + + return flags, nil +} + +// GetLinuxUserFlagSet returns the user flag set of the table. +func (t *Table) GetLinuxUserFlagSet() (uint8, *syserr.AnnotatedError) { + flags, err := t.GetLinuxFlagSet() + if err != nil { + return 0, err + } + return uint8(flags & linux.NFT_TABLE_F_MASK), nil +} + // GetChain returns the chain with the specified name if it exists, error // otherwise. func (t *Table) GetChain(chainName string) (*Chain, *syserr.AnnotatedError) { diff --git a/pkg/tcpip/nftables/nftables_types.go b/pkg/tcpip/nftables/nftables_types.go index 3682e92540..d3d73dcc70 100644 --- a/pkg/tcpip/nftables/nftables_types.go +++ b/pkg/tcpip/nftables/nftables_types.go @@ -139,10 +139,11 @@ func validateHook(hook stack.NFHook, family stack.AddressFamily) *syserr.Annotat // NFTables represents the nftables state for all address families. // Note: unlike iptables, nftables doesn't start with any initialized tables. type NFTables struct { - filters [stack.NumAFs]*addressFamilyFilter // Filters for each address family. - clock tcpip.Clock // Clock for timing evaluations. - startTime time.Time // Time NFTables object was created. - rng rand.RNG // Random number generator. + filters [stack.NumAFs]*addressFamilyFilter // Filters for each address family. + clock tcpip.Clock // Clock for timing evaluations. + startTime time.Time // Time NFTables object was created. + rng rand.RNG // Random number generator. + tableHandleCounter atomicbitops.Uint64 // Table handle counter. 
} // Ensures NFTables implements the NFTablesInterface. @@ -160,6 +161,9 @@ type addressFamilyFilter struct { // tables is a map of tables for each address family. tables map[string]*Table + // tableHandles is a map of table handles (ids) to tables for a given address family. + tableHandles map[uint64]*Table + // hfStacks is a map of hook function stacks (slice of base chains for a // given hook ordered by priority). hfStacks map[stack.NFHook]*hookFunctionStack @@ -179,9 +183,25 @@ type Table struct { // chains is a map of chains for each table. chains map[string]*Chain - // flags is the set of optional flags for the table. + // flagSet is the set of optional flags for the table. // Note: currently nftables only has the single Dormant flag. flagSet map[TableFlag]struct{} + + // handle is the id of the table. + handle uint64 + + // owner is the port id of the table's owner, if it is specified. + owner uint32 + + // userData is the user-specified metadata for the table. This is not used + // by the kernel, but rather userspace applications like nft binary. + userData []byte +} + +// TableInfo represents data between an AFfilter and a Table. +type TableInfo struct { + Name string + Handle uint64 } // hookFunctionStack represents the list of base chains for a specific hook. @@ -198,6 +218,9 @@ const ( // TableFlagDormant is set if the table is dormant. Dormant tables are not // evaluated by the kernel. TableFlagDormant TableFlag = iota + // TableFlagOwner is set if the table has an owner. The owner is the port + // where the table is created. + TableFlagOwner ) // Chain represents a single chain as a list of rules. @@ -745,3 +768,28 @@ func VerdictCodeToString(v uint32) string { } return fmt.Sprintf("invalid verdict: %d", v) } + +// netlinkAFToStackAF maps address families from linux/socket.h to their corresponding +// netfilter address families. 
+// From linux/include/uapi/linux/netfilter.h. The keys are spelled with the AF_* constants whose numeric values equal the NFPROTO_* families (e.g. AF_UNIX and NFPROTO_INET are both 1, AF_AX25 and NFPROTO_ARP are both 3).
+var netlinkAFToStackAF = map[uint8]stack.AddressFamily{
+	linux.AF_UNSPEC:    stack.Unspec,
+	linux.AF_UNIX:      stack.Inet,
+	linux.AF_INET:      stack.IP,
+	linux.AF_AX25:      stack.Arp,
+	linux.AF_APPLETALK: stack.Netdev,
+	linux.AF_BRIDGE:    stack.Bridge,
+	linux.AF_INET6:     stack.IP6,
+}
+
+// AFtoNetlinkAF converts the address family byte carried in a netfilter message into the
+// corresponding stack.AddressFamily. Families missing from netlinkAFToStackAF yield
+// stack.NumAFs, a sentinel that is not a valid family. No error is returned here because Linux
+// does not validate the family for every nftables operation, so each caller decides what to report.
+func AFtoNetlinkAF(af uint8) stack.AddressFamily {
+	naf, ok := netlinkAFToStackAF[af]
+	if !ok {
+		return stack.NumAFs
+	}
+	return naf
+}
diff --git a/pkg/tcpip/stack/nftables_types.go b/pkg/tcpip/stack/nftables_types.go
index 5856c587fd..113d8f1348 100644
--- a/pkg/tcpip/stack/nftables_types.go
+++ b/pkg/tcpip/stack/nftables_types.go
@@ -95,8 +95,11 @@ func (h NFHook) String() string {
 type AddressFamily int
 
 const (
+	// Unspec represents an unspecified address family.
+	Unspec AddressFamily = iota
+
 	// IP represents IPv4 Family.
-	IP AddressFamily = iota
+	IP
 
 	// IP6 represents IPv6 Family.
 	IP6
@@ -119,6 +122,7 @@ const (
 
 // AddressFamilyStrings maps address families to their string representation.
 var AddressFamilyStrings = map[AddressFamily]string{
+	Unspec: "UNSPEC",
 	IP:     "IPv4",
 	IP6:    "IPv6",
 	Inet:   "Internet (Both IPv4/IPv6)",
@@ -128,8 +132,9 @@ var AddressFamilyStrings = map[AddressFamily]string{
 }
 
 // ValidateAddressFamily ensures the family address is valid (within bounds).
+// Unspecified address family is not valid. It is only used to reference all address families.
func ValidateAddressFamily(family AddressFamily) error { - if family < 0 || family >= NumAFs { + if family < 1 || family >= NumAFs { return fmt.Errorf("invalid address family: %d", int(family)) } return nil diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 48bffefba0..c4d583e33e 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -3692,8 +3692,12 @@ cc_library( cc_library( name = "socket_netlink_netfilter_util", + testonly = 1, srcs = ["socket_netlink_netfilter_util.cc"], hdrs = ["socket_netlink_netfilter_util.h"], + deps = select_gtest() + [ + ":socket_netlink_util", + ], ) cc_binary( diff --git a/test/syscalls/linux/socket_netlink_netfilter.cc b/test/syscalls/linux/socket_netlink_netfilter.cc index 7578028e9d..e33140f021 100644 --- a/test/syscalls/linux/socket_netlink_netfilter.cc +++ b/test/syscalls/linux/socket_netlink_netfilter.cc @@ -12,13 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +// netinet/in.h must be included before netfilter.h. 
+// clang-format off +#include +#include +#include #include +// clang-format on #include +#include #include +#include #include #include #include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -41,7 +50,6 @@ namespace { constexpr uint32_t kSeq = 12345; using ::testing::_; -using ::testing::Eq; using SockOptTest = ::testing::TestWithParam< std::tuple, std::string>>; @@ -99,134 +107,83 @@ TEST(NetlinkNetfilterTest, CanCreateSocket) { TEST(NetlinkNetfilterTest, AddAndAddTableWithDormantFlag) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - uint16_t default_table_id = 0; const char test_table_name[] = "test_table"; + uint32_t table_flags = NFT_TABLE_F_DORMANT; FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); - struct nameAttribute { - struct nlattr attr; - char name[32]; - }; - struct flagAttribute { - struct nlattr attr; - uint32_t flags; - }; - struct request { - struct nlmsghdr hdr; - struct nfgenmsg msg; - struct nameAttribute nattr; - }; - - struct request_2 { - struct nlmsghdr hdr; - struct nfgenmsg msg; - struct nameAttribute nattr; - struct flagAttribute fattr; - }; - - struct request add_tab_req = {}; - InitNetlinkHdr(&add_tab_req.hdr, sizeof(add_tab_req), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), - kSeq, NLM_F_REQUEST | NLM_F_ACK); - // For both ipv4 and ipv6 tables. - InitNetfilterGenmsg(&add_tab_req.msg, AF_INET, NFNETLINK_V0, - default_table_id); - // Attribute setting - InitNetlinkAttr(&add_tab_req.nattr.attr, sizeof(add_tab_req.nattr.name), - NFTA_TABLE_NAME); - absl::SNPrintF(add_tab_req.nattr.name, sizeof(add_tab_req.nattr.name), - test_table_name); - - struct request_2 add_tab_req_2 = {}; - InitNetlinkHdr(&add_tab_req_2.hdr, sizeof(add_tab_req_2), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), - kSeq + 1, NLM_F_REQUEST | NLM_F_ACK); - // For both ipv4 and ipv6 tables. 
- InitNetfilterGenmsg(&add_tab_req_2.msg, AF_INET, NFNETLINK_V0, - default_table_id); - // Attribute setting - InitNetlinkAttr(&add_tab_req_2.nattr.attr, sizeof(add_tab_req_2.nattr.name), - NFTA_TABLE_NAME); - absl::SNPrintF(add_tab_req_2.nattr.name, sizeof(add_tab_req_2.nattr.name), - test_table_name); - InitNetlinkAttr(&add_tab_req_2.fattr.attr, sizeof(add_tab_req_2.fattr.flags), - NFTA_TABLE_FLAGS); - add_tab_req_2.fattr.flags = NFT_TABLE_F_DORMANT; - - ASSERT_NO_ERRNO( - NetlinkRequestAckOrError(fd, kSeq, &add_tab_req, sizeof(add_tab_req))); - ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq + 1, &add_tab_req_2, - sizeof(add_tab_req_2))); + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + std::vector add_request_buffer_2 = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Attr(NFTA_TABLE_FLAGS, &table_flags, sizeof(table_flags)) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size())); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError( + fd, kSeq + 1, add_request_buffer_2.data(), add_request_buffer_2.size())); } TEST(NetlinkNetfilterTest, AddAndRetrieveNewTable) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - uint16_t default_table_id = 0; - const char test_table_name[] = "test_table"; + const char test_table_name[] = "test_tab_add_retrieve"; + uint32_t table_flags = NFT_TABLE_F_DORMANT | NFT_TABLE_F_OWNER; + uint8_t expected_udata[] = {0x01, 0x02, 0x03, 0x04}; + uint32_t expected_chain_count = 0; + uint32_t expected_flags = table_flags; + size_t expected_udata_size = sizeof(expected_udata); + bool correct_response = false; FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); - - 
struct nameAttribute { - struct nlattr attr; - char name[32]; - }; - struct request { - struct nlmsghdr hdr; - struct nfgenmsg msg; - struct nameAttribute attr; - }; - - struct request add_tab_req = {}; - InitNetlinkHdr(&add_tab_req.hdr, sizeof(add_tab_req), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), - kSeq, NLM_F_REQUEST | NLM_F_ACK); - // For both ipv4 and ipv6 tables. - InitNetfilterGenmsg(&add_tab_req.msg, AF_INET, NFNETLINK_V0, - default_table_id); - // Attribute setting - InitNetlinkAttr(&add_tab_req.attr.attr, sizeof(add_tab_req.attr.name), - NFTA_TABLE_NAME); - absl::SNPrintF(add_tab_req.attr.name, sizeof(add_tab_req.attr.name), - test_table_name); - - struct request add_tab_req_2 = {}; - bool correct_response = false; - InitNetlinkHdr(&add_tab_req_2.hdr, sizeof(add_tab_req_2), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_GETTABLE), - kSeq + 1, NLM_F_REQUEST); - // For both ipv4 and ipv6 tables. - InitNetfilterGenmsg(&add_tab_req_2.msg, AF_INET, NFNETLINK_V0, - default_table_id); - // Attribute setting - InitNetlinkAttr(&add_tab_req_2.attr.attr, sizeof(add_tab_req_2.attr.name), - NFTA_TABLE_NAME); - absl::SNPrintF(add_tab_req_2.attr.name, sizeof(add_tab_req_2.attr.name), - test_table_name); - - ASSERT_NO_ERRNO( - NetlinkRequestAckOrError(fd, kSeq, &add_tab_req, sizeof(add_tab_req))); + uint32_t expected_owner = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); + + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + // Include the null terminator at the end of the string. + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Attr(NFTA_TABLE_FLAGS, &table_flags, sizeof(table_flags)) + .Attr(NFTA_TABLE_USERDATA, expected_udata, sizeof(expected_udata)) + .Build(); + + std::vector get_request_buffer = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + // Don't set NLM_F_ACK here, since the check will be done for every + // nlmsg received. 
+ .Flags(NLM_F_REQUEST) + .Family(NFPROTO_INET) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size())); ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &add_tab_req_2, sizeof(add_tab_req_2), + fd, get_request_buffer.data(), get_request_buffer.size(), [&](const struct nlmsghdr* hdr) { - ASSERT_THAT(hdr->nlmsg_type, Eq(MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, - NFT_MSG_GETTABLE))); - ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct nfgenmsg))); - const struct nfgenmsg* genmsg = - reinterpret_cast(NLMSG_DATA(hdr)); - EXPECT_EQ(genmsg->nfgen_family, AF_INET); - EXPECT_EQ(genmsg->version, NFNETLINK_V0); - - const struct nfattr* nfattr = FindNfAttr(hdr, genmsg, NFTA_TABLE_NAME); - EXPECT_NE(nullptr, nfattr) << "NFTA_TABLE_NAME not found in message."; - if (nfattr == nullptr) { - return; - } - - std::string name(reinterpret_cast(NFA_DATA(nfattr))); - EXPECT_EQ(name, test_table_name); + CheckNetfilterTableAttributes( + hdr, nullptr, test_table_name, &expected_chain_count, nullptr, + &expected_flags, &expected_owner, expected_udata, + &expected_udata_size, true); correct_response = true; }, false)); @@ -234,207 +191,618 @@ TEST(NetlinkNetfilterTest, AddAndRetrieveNewTable) { ASSERT_TRUE(correct_response); } +TEST(NetlinkNetfilterTest, ErrGettingTableWithDifferentFamily) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + const char test_table_name[] = "test_tab_different_families"; + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + std::vector add_request_buffer_ipv4 = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_IPV4) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + std::vector add_request_buffer_ipv6 = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_IPV6) + 
.Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + std::vector get_request_buffer = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST) + .Family(NFPROTO_INET) + .Seq(kSeq + 2) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, + add_request_buffer_ipv4.data(), + add_request_buffer_ipv4.size())); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq + 1, + add_request_buffer_ipv6.data(), + add_request_buffer_ipv6.size())); + ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq + 2, get_request_buffer.data(), + get_request_buffer.size()), + PosixErrorIs(ENOENT, _)); +} + TEST(NetlinkNetfilterTest, ErrAddExistingTableWithExclusiveFlag) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - uint16_t default_table_id = 0; - const char test_table_name[] = "test_table"; + const char test_table_name[] = "err_exclusive"; FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); - struct nameAttribute { - struct nlattr attr; - char name[32]; - }; - struct request { - struct nlmsghdr hdr; - struct nfgenmsg msg; - struct nameAttribute attr; - }; - - struct request add_tab_req = {}; - InitNetlinkHdr(&add_tab_req.hdr, sizeof(add_tab_req), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), - kSeq, NLM_F_REQUEST | NLM_F_ACK); - // For both ipv4 and ipv6 tables. - InitNetfilterGenmsg(&add_tab_req.msg, AF_INET, NFNETLINK_V0, - default_table_id); - // Attribute setting - InitNetlinkAttr(&add_tab_req.attr.attr, sizeof(add_tab_req.attr.name), - NFTA_TABLE_NAME); - absl::SNPrintF(add_tab_req.attr.name, sizeof(add_tab_req.attr.name), - test_table_name); - - struct request add_tab_req_2 = {}; - InitNetlinkHdr(&add_tab_req_2.hdr, sizeof(add_tab_req_2), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), - kSeq + 1, NLM_F_REQUEST | NLM_F_EXCL); - // For both ipv4 and ipv6 tables. 
- InitNetfilterGenmsg(&add_tab_req_2.msg, AF_INET, NFNETLINK_V0, - default_table_id); - // Attribute setting - InitNetlinkAttr(&add_tab_req_2.attr.attr, sizeof(add_tab_req_2.attr.name), - NFTA_TABLE_NAME); - absl::SNPrintF(add_tab_req_2.attr.name, sizeof(add_tab_req_2.attr.name), - test_table_name); - - ASSERT_NO_ERRNO( - NetlinkRequestAckOrError(fd, kSeq, &add_tab_req, sizeof(add_tab_req))); - ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq + 1, &add_tab_req_2, - sizeof(add_tab_req_2)), - PosixErrorIs(EEXIST, _)); + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + std::vector add_request_buffer_2 = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_EXCL) + .Family(NFPROTO_INET) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size())); + ASSERT_THAT( + NetlinkRequestAckOrError(fd, kSeq + 1, add_request_buffer_2.data(), + add_request_buffer_2.size()), + PosixErrorIs(EEXIST, _)); } TEST(NetlinkNetfilterTest, ErrAddExistingTableWithReplaceFlag) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - uint16_t default_table_id = 0; - const char test_table_name[] = "test_table"; + const char test_table_name[] = "err_replace"; FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); - struct nameAttribute { - struct nlattr attr; - char name[32]; - }; - struct request { - struct nlmsghdr hdr; - struct nfgenmsg msg; - struct nameAttribute attr; - }; - - struct request add_tab_req = {}; - InitNetlinkHdr(&add_tab_req.hdr, sizeof(add_tab_req), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), - kSeq, NLM_F_REQUEST | NLM_F_ACK); - // For both ipv4 and ipv6 tables. 
- InitNetfilterGenmsg(&add_tab_req.msg, AF_INET, NFNETLINK_V0, - default_table_id); - // Attribute setting - InitNetlinkAttr(&add_tab_req.attr.attr, sizeof(add_tab_req.attr.name), - NFTA_TABLE_NAME); - absl::SNPrintF(add_tab_req.attr.name, sizeof(add_tab_req.attr.name), - test_table_name); - - struct request add_tab_req_2 = {}; - InitNetlinkHdr(&add_tab_req_2.hdr, sizeof(add_tab_req_2), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), - kSeq + 1, NLM_F_REQUEST | NLM_F_REPLACE); - // For both ipv4 and ipv6 tables. - InitNetfilterGenmsg(&add_tab_req_2.msg, AF_INET, NFNETLINK_V0, - default_table_id); - // Attribute setting - InitNetlinkAttr(&add_tab_req_2.attr.attr, sizeof(add_tab_req_2.attr.name), - NFTA_TABLE_NAME); - absl::SNPrintF(add_tab_req_2.attr.name, sizeof(add_tab_req_2.attr.name), - test_table_name); - - ASSERT_NO_ERRNO( - NetlinkRequestAckOrError(fd, kSeq, &add_tab_req, sizeof(add_tab_req))); - ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq + 1, &add_tab_req_2, - sizeof(add_tab_req_2)), - PosixErrorIs(ENOTSUP, _)); + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + std::vector add_request_buffer_2 = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_REPLACE) + .Family(NFPROTO_INET) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size())); + ASSERT_THAT( + NetlinkRequestAckOrError(fd, kSeq + 1, add_request_buffer_2.data(), + add_request_buffer_2.size()), + PosixErrorIs(ENOTSUP, _)); } TEST(NetlinkNetfilterTest, ErrAddTableWithUnsupportedFamily) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); uint8_t unknown_family = 255; - uint16_t default_table_id = 0; const char test_table_name[] = "unsupported_family_table"; FileDescriptor 
fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); - struct nameAttribute { - struct nlattr attr; - char name[32]; - }; - struct request { - struct nlmsghdr hdr; - struct nfgenmsg msg; - struct nameAttribute attr; - }; + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST) + .Family(unknown_family) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size()), + PosixErrorIs(ENOTSUP, _)); +} - struct request get_tab_req = {}; - InitNetlinkHdr(&get_tab_req.hdr, sizeof(get_tab_req), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), - kSeq, NLM_F_REQUEST); - // For both ipv4 and ipv6 tables. - InitNetfilterGenmsg(&get_tab_req.msg, unknown_family, NFNETLINK_V0, - default_table_id); - // Attribute setting - InitNetlinkAttr(&get_tab_req.attr.attr, sizeof(get_tab_req.attr.name), - NFTA_TABLE_NAME); - absl::SNPrintF(get_tab_req.attr.name, sizeof(get_tab_req.attr.name), - test_table_name); +TEST(NetlinkNetfilterTest, ErrAddTableWithUnsupportedFlags) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + uint32_t unsupported_flags = 0xFFFFFFFF; + const char test_table_name[] = "test_table"; - ASSERT_THAT( - NetlinkRequestAckOrError(fd, kSeq, &get_tab_req, sizeof(get_tab_req)), - PosixErrorIs(ENOTSUP, _)); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Attr(NFTA_TABLE_FLAGS, &unsupported_flags, sizeof(unsupported_flags)) + .Build(); + + ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size()), + PosixErrorIs(ENOTSUP, _)); } TEST(NetlinkNetfilterTest, ErrRetrieveNoSpecifiedNameTable) { 
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - uint16_t default_table_id = 0; FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); - struct nameAttribute { - struct nlattr attr; - char name[32]; - }; - struct request { - struct nlmsghdr hdr; - struct nfgenmsg msg; - }; - - struct request get_tab_req = {}; - InitNetlinkHdr(&get_tab_req.hdr, sizeof(get_tab_req), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_GETTABLE), - kSeq, NLM_F_REQUEST); - // For both ipv4 and ipv6 tables. - InitNetfilterGenmsg(&get_tab_req.msg, AF_INET, NFNETLINK_V0, - default_table_id); + std::vector get_request_buffer = NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .Build(); - ASSERT_THAT( - NetlinkRequestAckOrError(fd, kSeq, &get_tab_req, sizeof(get_tab_req)), - PosixErrorIs(EINVAL, _)); + ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq, get_request_buffer.data(), + get_request_buffer.size()), + PosixErrorIs(EINVAL, _)); } TEST(NetlinkNetfilterTest, ErrRetrieveNonexistentTable) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - uint16_t default_table_id = 0; const char test_table_name[] = "undefined_table"; FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); - struct nameAttribute { - struct nlattr attr; - char name[32]; - }; + std::vector get_request_buffer = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq, get_request_buffer.data(), + get_request_buffer.size()), + PosixErrorIs(ENOENT, _)); +} + +TEST(NetlinkNetfilterTest, ErrRetrieveTableWithOwnerMismatch) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + const char test_table_name[] = "test_table"; + uint32_t table_flags = NFT_TABLE_F_DORMANT | NFT_TABLE_F_OWNER; + uint8_t 
expected_udata[3] = {0x01, 0x02, 0x03}; + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + FileDescriptor fd_2 = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Attr(NFTA_TABLE_FLAGS, &table_flags, sizeof(table_flags)) + .Attr(NFTA_TABLE_USERDATA, expected_udata, sizeof(expected_udata)) + .Build(); + + std::vector get_request_buffer = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size())); + + ASSERT_THAT( + NetlinkRequestAckOrError(fd_2, kSeq + 1, get_request_buffer.data(), + get_request_buffer.size()), + PosixErrorIs(EPERM, _)); +} + +TEST(NetlinkNetfilterTest, DeleteExistingTableByName) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + // uint16_t default_table_id = 0; + const char test_table_name[] = "test_table_name_delete"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + struct request { struct nlmsghdr hdr; struct nfgenmsg msg; - struct nameAttribute attr; + struct nameAttribute nattr; }; - struct request get_tab_req = {}; - InitNetlinkHdr(&get_tab_req.hdr, sizeof(get_tab_req), - MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_GETTABLE), - kSeq, NLM_F_REQUEST); - // For both ipv4 and ipv6 tables. 
- InitNetfilterGenmsg(&get_tab_req.msg, AF_INET, NFNETLINK_V0, - default_table_id); - // Attribute setting - InitNetlinkAttr(&get_tab_req.attr.attr, sizeof(get_tab_req.attr.name), - NFTA_TABLE_NAME); - absl::SNPrintF(get_tab_req.attr.name, sizeof(get_tab_req.attr.name), - test_table_name); + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + std::vector del_request_buffer = + NlReq() + .MsgType(NFT_MSG_DELTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size())); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError( + fd, kSeq + 1, del_request_buffer.data(), del_request_buffer.size())); +} + +TEST(NetlinkNetfilterTest, DeleteTableByHandle) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + // Retrieve the actual table handle from the kernel with a GET request. + uint64_t expected_handle = 0; + const char test_table_name[] = "test_table_handle_delete"; + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + std::vector get_request_buffer = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST) + .Family(NFPROTO_INET) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size())); + + // Retrieve the table handle from the kernel. 
+ ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, get_request_buffer.data(), get_request_buffer.size(), + [&](const struct nlmsghdr* hdr) { + const nfattr* attr = FindNfAttr(hdr, nullptr, NFTA_TABLE_HANDLE); + EXPECT_NE(attr, nullptr); + EXPECT_EQ(attr->nfa_type, NFTA_TABLE_HANDLE); + EXPECT_EQ(attr->nfa_len - NLA_HDRLEN, sizeof(expected_handle)); + expected_handle = *reinterpret_cast(NFA_DATA(attr)); + }, + false)); + EXPECT_NE(expected_handle, 0); + + std::vector del_request_buffer = + NlReq() + .MsgType(NFT_MSG_DELTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq + 2) + .Attr(NFTA_TABLE_HANDLE, &expected_handle, sizeof(expected_handle)) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError( + fd, kSeq + 2, del_request_buffer.data(), del_request_buffer.size())); +} + +TEST(NetlinkNetfilterTest, ErrDeleteNonexistentTable) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + const char test_table_name[] = "nonexistent_table"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + std::vector del_request_buffer = + NlReq() + .MsgType(NFT_MSG_DELTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq + 1, del_request_buffer.data(), + del_request_buffer.size()), + PosixErrorIs(ENOENT, _)); +} + +TEST(NetlinkNetfilterTest, DestroyNonexistentTable) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + const char test_table_name[] = "nonexistent_table"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + std::vector destroy_request_buffer = + NlReq() + .MsgType(NFT_MSG_DESTROYTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq + 1, + 
destroy_request_buffer.data(), + destroy_request_buffer.size())); +} + +TEST(NetlinkNetfilterTest, DeleteAllTablesUnspecifiedFamily) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + char test_table_name_inet[] = "test_table_inet"; + char test_table_name_bridge[] = "test_table_bridge"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name_inet) + .Build(); + + std::vector add_request_buffer_2 = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_BRIDGE) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name_bridge) + .Build(); + + std::vector destroy_request_buffer = + NlReq() + .MsgType(NFT_MSG_DELTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_UNSPEC) + .Seq(kSeq + 2) + .Build(); + + std::vector get_request_buffer = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq + 3) + .StrAttr(NFTA_TABLE_NAME, test_table_name_inet) + .Build(); + + std::vector get_request_buffer_2 = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_BRIDGE) + .Seq(kSeq + 4) + .StrAttr(NFTA_TABLE_NAME, test_table_name_bridge) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size())); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError( + fd, kSeq + 1, add_request_buffer_2.data(), add_request_buffer_2.size())); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq + 2, + destroy_request_buffer.data(), + destroy_request_buffer.size())); + ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq + 3, get_request_buffer.data(), + get_request_buffer.size()), + PosixErrorIs(ENOENT, _)); + ASSERT_THAT( + NetlinkRequestAckOrError(fd, kSeq + 4, 
get_request_buffer_2.data(), + get_request_buffer_2.size()), + PosixErrorIs(ENOENT, _)); +} + +TEST(NetlinkNetfilterTest, DeleteAllTablesUnspecifiedFamilySpecifiedTableName) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + char test_table_name_same[] = "test_same_name_table"; + char test_table_name_different[] = "test_different_name_table"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + std::vector add_request_buffer_inet = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name_same) + .Build(); + + std::vector add_request_buffer_bridge = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_BRIDGE) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name_same) + .Build(); + + std::vector add_request_buffer_different_bridge = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_BRIDGE) + .Seq(kSeq + 2) + .StrAttr(NFTA_TABLE_NAME, test_table_name_different) + .Build(); + + std::vector destroy_request_buffer = + NlReq() + .MsgType(NFT_MSG_DELTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_UNSPEC) + .Seq(kSeq + 3) + .StrAttr(NFTA_TABLE_NAME, test_table_name_same) + .Build(); + + std::vector get_request_buffer_inet = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq + 4) + .StrAttr(NFTA_TABLE_NAME, test_table_name_same) + .Build(); + + std::vector get_request_buffer_bridge = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_BRIDGE) + .Seq(kSeq + 5) + .StrAttr(NFTA_TABLE_NAME, test_table_name_same) + .Build(); + + std::vector get_request_buffer_different = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST) + .Family(NFPROTO_BRIDGE) + .Seq(kSeq + 6) + .StrAttr(NFTA_TABLE_NAME, 
test_table_name_different) + .Build(); + + bool correct_response = false; + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, + add_request_buffer_inet.data(), + add_request_buffer_inet.size())); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq + 1, + add_request_buffer_bridge.data(), + add_request_buffer_bridge.size())); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError( + fd, kSeq + 2, add_request_buffer_different_bridge.data(), + add_request_buffer_different_bridge.size())); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq + 3, + destroy_request_buffer.data(), + destroy_request_buffer.size())); + ASSERT_THAT( + NetlinkRequestAckOrError(fd, kSeq + 4, get_request_buffer_inet.data(), + get_request_buffer_inet.size()), + PosixErrorIs(ENOENT, _)); + ASSERT_THAT( + NetlinkRequestAckOrError(fd, kSeq + 5, get_request_buffer_bridge.data(), + get_request_buffer_bridge.size()), + PosixErrorIs(ENOENT, _)); + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, get_request_buffer_different.data(), + get_request_buffer_different.size(), + [&](const struct nlmsghdr* hdr) { + const struct nfattr* table_name_attr = + FindNfAttr(hdr, nullptr, NFTA_TABLE_NAME); + EXPECT_NE(table_name_attr, nullptr); + EXPECT_EQ(table_name_attr->nfa_type, NFTA_TABLE_NAME); + std::string name( + reinterpret_cast(NFA_DATA(table_name_attr))); + EXPECT_EQ(name, test_table_name_different); + correct_response = true; + }, + false)); + + ASSERT_TRUE(correct_response); +} + +TEST(NetlinkNetfilterTest, DeleteAllTablesUnspecifiedNameAndHandle) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + char test_table_name_inet[] = "test_table_inet"; + char test_table_name_bridge[] = "test_table_bridge"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + std::vector add_request_buffer = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_INET) + .Seq(kSeq) + .StrAttr(NFTA_TABLE_NAME, test_table_name_inet) + .Build(); + 
+ std::vector add_request_buffer_2 = + NlReq() + .MsgType(NFT_MSG_NEWTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_BRIDGE) + .Seq(kSeq + 1) + .StrAttr(NFTA_TABLE_NAME, test_table_name_bridge) + .Build(); + + std::vector destroy_request_buffer = + NlReq() + .MsgType(NFT_MSG_DELTABLE) + .Flags(NLM_F_REQUEST | NLM_F_ACK) + .Family(NFPROTO_UNSPEC) + .Seq(kSeq + 2) + .Build(); + + std::vector get_request_buffer = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST) + .Family(NFPROTO_INET) + .Seq(kSeq + 3) + .StrAttr(NFTA_TABLE_NAME, test_table_name_inet) + .Build(); + + std::vector get_request_buffer_2 = + NlReq() + .MsgType(NFT_MSG_GETTABLE) + .Flags(NLM_F_REQUEST) + .Family(NFPROTO_BRIDGE) + .Seq(kSeq + 4) + .StrAttr(NFTA_TABLE_NAME, test_table_name_bridge) + .Build(); + + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, add_request_buffer.data(), + add_request_buffer.size())); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError( + fd, kSeq + 1, add_request_buffer_2.data(), add_request_buffer_2.size())); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq + 2, + destroy_request_buffer.data(), + destroy_request_buffer.size())); + ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq + 3, get_request_buffer.data(), + get_request_buffer.size()), + PosixErrorIs(ENOENT, _)); ASSERT_THAT( - NetlinkRequestAckOrError(fd, kSeq, &get_tab_req, sizeof(get_tab_req)), + NetlinkRequestAckOrError(fd, kSeq + 4, get_request_buffer_2.data(), + get_request_buffer_2.size()), PosixErrorIs(ENOENT, _)); } diff --git a/test/syscalls/linux/socket_netlink_netfilter_util.cc b/test/syscalls/linux/socket_netlink_netfilter_util.cc index 54a7ebef0d..01e0e24578 100644 --- a/test/syscalls/linux/socket_netlink_netfilter_util.cc +++ b/test/syscalls/linux/socket_netlink_netfilter_util.cc @@ -14,11 +14,94 @@ #include "test/syscalls/linux/socket_netlink_netfilter_util.h" +#include #include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include 
"test/syscalls/linux/socket_netlink_util.h" namespace gvisor { namespace testing { +NlReq& NlReq::MsgType(uint8_t message_type) { + message_type_ = message_type; + return *this; +} + +NlReq& NlReq::Flags(uint16_t flags) { + flags_ = flags; + return *this; +} + +NlReq& NlReq::Seq(uint32_t seq) { + seq_ = seq; + return *this; +} + +NlReq& NlReq::Family(uint8_t family) { + family_ = family; + return *this; +} + +// Method to add an attribute to the message. payload_size must be the size of +// the payload in bytes. +NlReq& NlReq::Attr(uint16_t attr_type, const void* payload, + size_t payload_size) { + // Store a pointer to the payload and the size to construct it later. + attributes_[attr_type] = {reinterpret_cast(payload), + payload_size}; + return *this; +} + +// Method to add a string attribute to the message. +// The payload is expected to be a null-terminated string. +NlReq& NlReq::StrAttr(uint16_t attr_type, const char* payload) { + attributes_[attr_type] = {payload, strlen(payload) + 1}; + return *this; +} + +std::vector NlReq::Build() { + size_t aligned_hdr_size = NLMSG_ALIGN(sizeof(nlmsghdr)); + size_t aligned_genmsg_size = NLMSG_ALIGN(sizeof(nfgenmsg)); + size_t total_attr_size = 0; + + for (const auto& [attr_type, attr_data] : attributes_) { + const auto& [_, payload_size] = attr_data; + total_attr_size += NLA_ALIGN(NLA_HDRLEN + payload_size); + } + + size_t total_message_len = + NLMSG_ALIGN(aligned_hdr_size + aligned_genmsg_size + total_attr_size); + + msg_buffer_.resize(total_message_len); + std::memset(msg_buffer_.data(), 0, total_message_len); + + struct nlmsghdr* nlh = reinterpret_cast(msg_buffer_.data()); + InitNetlinkHdr(nlh, (uint32_t)total_message_len, + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, message_type_), seq_, + flags_); + + struct nfgenmsg* nfg = reinterpret_cast(NLMSG_DATA(nlh)); + InitNetfilterGenmsg(nfg, family_, NFNETLINK_V0, 0); + + char* payload = + (char*)msg_buffer_.data() + aligned_hdr_size + aligned_genmsg_size; + + for (const 
auto& [attr_type, attr_data] : attributes_) { + const auto& [payload_data, payload_size] = attr_data; + struct nlattr* attr = reinterpret_cast(payload); + InitNetlinkAttr(attr, payload_size, attr_type); + std::memcpy((char*)attr + NLA_HDRLEN, payload_data, payload_size); + // Move over to the next attribute. + payload += NLA_ALIGN(NLA_HDRLEN + payload_size); + } + return msg_buffer_; +} + // Helper function to initialize a nfgenmsg header. void InitNetfilterGenmsg(struct nfgenmsg* genmsg, uint8_t family, uint8_t version, uint16_t res_id) { @@ -27,5 +110,81 @@ void InitNetfilterGenmsg(struct nfgenmsg* genmsg, uint8_t family, genmsg->res_id = res_id; } +// Helper function to check the netfilter table attributes. +void CheckNetfilterTableAttributes( + const struct nlmsghdr* hdr, const struct nfgenmsg* genmsg, + const char* test_table_name, uint32_t* expected_chain_count, + uint64_t* expected_handle, uint32_t* expected_flags, + uint32_t* expected_owner, uint8_t* expected_udata, + size_t* expected_udata_size, bool skip_handle_check) { + // Check for the NFTA_TABLE_NAME attribute. + const struct nfattr* table_name_attr = + FindNfAttr(hdr, genmsg, NFTA_TABLE_NAME); + if (table_name_attr != nullptr && test_table_name != nullptr) { + std::string name(reinterpret_cast(NFA_DATA(table_name_attr))); + EXPECT_EQ(name, test_table_name); + } else { + EXPECT_EQ(table_name_attr, nullptr); + EXPECT_EQ(test_table_name, nullptr); + } + + // Check for the NFTA_TABLE_USE attribute. + const struct nfattr* table_use_attr = FindNfAttr(hdr, genmsg, NFTA_TABLE_USE); + if (table_use_attr != nullptr && expected_chain_count != nullptr) { + uint32_t count = *(reinterpret_cast(NFA_DATA(table_use_attr))); + EXPECT_EQ(count, *expected_chain_count); + } else { + EXPECT_EQ(table_use_attr, nullptr); + EXPECT_EQ(expected_chain_count, nullptr); + } + + if (!skip_handle_check) { + // Check for the NFTA_TABLE_HANDLE attribute. 
+ const struct nfattr* handle_attr = + FindNfAttr(hdr, genmsg, NFTA_TABLE_HANDLE); + if (handle_attr != nullptr && expected_handle != nullptr) { + uint64_t handle = *(reinterpret_cast(NFA_DATA(handle_attr))); + EXPECT_EQ(handle, *expected_handle); + } else { + EXPECT_EQ(handle_attr, nullptr); + EXPECT_EQ(expected_handle, nullptr); + } + } + + // Check for the NFTA_TABLE_FLAGS attribute. + const struct nfattr* flags_attr = FindNfAttr(hdr, genmsg, NFTA_TABLE_FLAGS); + if (flags_attr != nullptr && expected_flags != nullptr) { + uint32_t flags = *(reinterpret_cast(NFA_DATA(flags_attr))); + EXPECT_EQ(flags, *expected_flags); + } else { + EXPECT_EQ(flags_attr, nullptr); + EXPECT_EQ(expected_flags, nullptr); + } + + // Check for the NFTA_TABLE_OWNER attribute; expected_owner may be null. + const struct nfattr* owner_attr = FindNfAttr(hdr, genmsg, NFTA_TABLE_OWNER); + if (owner_attr != nullptr && expected_owner != nullptr) { + uint32_t owner = *(reinterpret_cast(NFA_DATA(owner_attr))); + EXPECT_EQ(owner, *expected_owner); + } else { + EXPECT_EQ(owner_attr, nullptr); + EXPECT_EQ(expected_owner, nullptr); + } + + // Check for the NFTA_TABLE_USERDATA attribute. 
+ const struct nfattr* user_data_attr = + FindNfAttr(hdr, genmsg, NFTA_TABLE_USERDATA); + + if (user_data_attr != nullptr && expected_udata_size != nullptr) { + uint8_t user_data[VALID_USERDATA_SIZE] = {}; + EXPECT_EQ(user_data_attr->nfa_len - NLA_HDRLEN, *expected_udata_size); + std::memcpy(user_data, NFA_DATA(user_data_attr), *expected_udata_size); + EXPECT_EQ(memcmp(user_data, expected_udata, *expected_udata_size), 0); + } else { + EXPECT_EQ(user_data_attr, nullptr); + EXPECT_EQ(expected_udata_size, nullptr); + } +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_netfilter_util.h b/test/syscalls/linux/socket_netlink_netfilter_util.h index 8241ed3c12..4a9426c5f6 100644 --- a/test/syscalls/linux/socket_netlink_netfilter_util.h +++ b/test/syscalls/linux/socket_netlink_netfilter_util.h @@ -15,20 +15,96 @@ #ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_NETFILTER_UTIL_H_ #define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_NETFILTER_UTIL_H_ +#include +#include +#include +#ifndef NFTA_TABLE_OWNER +#define NFTA_TABLE_OWNER (NFTA_TABLE_USERDATA + 1) +#endif + +#ifndef NFT_TABLE_F_OWNER +#define NFT_TABLE_F_OWNER 2 +#endif + +#ifndef NFT_MSG_DESTROYTABLE +#define NFT_MSG_DESTROYTABLE 26 +#endif + #include #include #include #include +// clang-format off +#include +#include +// clang-format on #include +#include #include namespace gvisor { namespace testing { +#define TABLE_NAME_SIZE 32 +#define VALID_USERDATA_SIZE 128 + +struct nameAttribute { + struct nlattr attr; + char name[TABLE_NAME_SIZE]; +}; +struct flagAttribute { + struct nlattr attr; + uint32_t flags; +}; +struct userDataAttribute { + struct nlattr attr; + uint8_t userdata[VALID_USERDATA_SIZE]; +}; +struct deleteAttribute { + struct nlattr attr; + uint32_t handle; +}; + void InitNetfilterGenmsg(struct nfgenmsg* genmsg, uint8_t family, uint8_t version, uint16_t res_id); +void CheckNetfilterTableAttributes( + const struct nlmsghdr* hdr, const struct nfgenmsg* genmsg, + 
const char* test_table_name, uint32_t* expected_chain_count, + uint64_t* expected_handle, uint32_t* expected_flags, + uint32_t* expected_owner, uint8_t* expected_udata, + size_t* expected_udata_size, bool skip_handle_check); + +class NlReq { + public: + NlReq() = default; + + NlReq& MsgType(uint8_t message_type); + NlReq& Flags(uint16_t flags); + NlReq& Seq(uint32_t seq); + NlReq& Family(uint8_t family); + + // Method to add an attribute to the message. payload_size is the payload + // size in bytes. Only the pointer is stored, so the payload must outlive + // Build(). Adding the same attr_type again replaces the earlier entry. + NlReq& Attr(uint16_t attr_type, const void* payload, size_t payload_size); + + // Method to add a string attribute to the message. + // The payload must be a null-terminated string that outlives Build(). + NlReq& StrAttr(uint16_t attr_type, const char* payload); + + std::vector Build(); + + private: + uint8_t message_type_ = 0; + uint16_t flags_ = 0; + uint32_t seq_ = 0; + uint8_t family_ = 0; + std::map> attributes_ = {}; + std::vector msg_buffer_; +}; + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index d4ad57d6a2..c45b94afcc 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -211,7 +211,7 @@ uint16_t MakeNetlinkMsgType(uint8_t subsys_id, uint8_t msg_type) { // Helper function to initialize a netlink header. void InitNetlinkHdr(struct nlmsghdr* hdr, uint32_t msg_len, uint16_t msg_type, uint32_t seq, uint16_t flags) { - hdr->nlmsg_len = msg_len; + hdr->nlmsg_len = NLMSG_ALIGN(msg_len); hdr->nlmsg_type = msg_type; hdr->nlmsg_flags = flags; hdr->nlmsg_seq = seq; @@ -220,7 +220,7 @@ void InitNetlinkHdr(struct nlmsghdr* hdr, uint32_t msg_len, uint16_t msg_type, // Helper function to initialize a netlink attribute. 
void InitNetlinkAttr(struct nlattr* attr, int payload_size, uint16_t attr_type) { - attr->nla_len = NLA_HDRLEN + payload_size; + attr->nla_len = NLA_ALIGN(NLA_HDRLEN + payload_size); attr->nla_type = attr_type; }