From bf8e0133609cbf7084d498dd55b0e08f2d677bb9 Mon Sep 17 00:00:00 2001 From: Alexander Cueva Date: Fri, 8 Aug 2025 13:20:07 -0700 Subject: [PATCH] Add functions deep copy and replace NFTables structure. This change adds a DeepCopy method to copy all the elements of an NFTables struct and a method to replace the data of a NFTables struct. Needed for atomic, rollback-able changes when processing batch messages. PiperOrigin-RevId: 792738351 --- pkg/tcpip/nftables/nftables.go | 10 +++ pkg/tcpip/nftables/nftables_types.go | 130 +++++++++++++++++++++++++++ 2 files changed, 140 insertions(+) diff --git a/pkg/tcpip/nftables/nftables.go b/pkg/tcpip/nftables/nftables.go index 6f89e692a3..8af6eab470 100644 --- a/pkg/tcpip/nftables/nftables.go +++ b/pkg/tcpip/nftables/nftables.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/nlmsg" "gvisor.dev/gvisor/pkg/syserr" @@ -270,6 +271,15 @@ func (nf *NFTables) Flush(attrs map[uint16]nlmsg.BytesView, owner uint32) { continue } + // TODO: b/434242152 - Support correctly deleting chains once + // rules are deletable. + for chainName := range table.chains { + ok := table.DeleteChain(chainName) + if !ok { + log.Warningf("Failed to delete chain %s", chainName) + } + } + tablesToDelete = append(tablesToDelete, TableInfo{Name: name, Handle: table.GetHandle()}) } diff --git a/pkg/tcpip/nftables/nftables_types.go b/pkg/tcpip/nftables/nftables_types.go index f9c3af153d..bb729e535a 100644 --- a/pkg/tcpip/nftables/nftables_types.go +++ b/pkg/tcpip/nftables/nftables_types.go @@ -388,6 +388,7 @@ type Chain struct { // BaseChainInfo stores hook-related info for attaching a chain to the pipeline. type BaseChainInfo struct { + // LINT.IfChange(base_chain_info) // BcType is the base chain type of the chain (filter, nat, route). BcType BaseChainType @@ -414,6 +415,8 @@ type BaseChainInfo struct { // explicitly accepted or rejected by the rules. A chain's policy defaults to // Accept, but this can be used to specify otherwise. PolicyDrop bool + + // LINT.ThenChange(:base_chain_info_copy) } // PolicyBoolToValue converts the policy drop boolean to a uint8. @@ -1109,3 +1112,130 @@ func HasAttr(attrName uint16, attrs map[uint16]nlmsg.BytesView) bool { _, ok := attrs[attrName] return ok } + +// deepCopyRule returns a deep copy of the Rule struct. +func deepCopyRule(rule *Rule, chainCopy *Chain) *Rule { + return &Rule{ + chain: chainCopy, + // Because the underlying op data within the slice cannot be + // modified, creating a shallow copy is sufficient. Even if the + // original struct is modified and an operation is dropped, + // the copy will hold a reference to the original operation, + // preventing it from being destroyed. + ops: slices.Clone(rule.ops), + handle: rule.handle, + udata: slices.Clone(rule.udata), + } +} + +// deepCopyChain returns a deep copy of the Chain struct. +func deepCopyChain(chain *Chain, tableCopy *Table) *Chain { + chainCopy := &Chain{ + name: chain.name, + table: tableCopy, + handle: chain.handle, + flags: chain.flags, + handleToRule: make(map[uint64]*Rule), + userData: slices.Clone(chain.userData), + chainUse: chain.chainUse, + bound: chain.bound, + comment: chain.comment, + } + + // LINT.IfChange(base_chain_info_copy) + + // BaseChainInfo is immutable after creation and it only contains + // primitives, so we can safely copy it. + if chain.baseChainInfo != nil { + chainCopy.baseChainInfo = &BaseChainInfo{} + *chainCopy.baseChainInfo = *chain.baseChainInfo + } + + // LINT.ThenChange() + + for _, rule := range chain.rules { + ruleCopy := deepCopyRule(rule, chainCopy) + chainCopy.rules = append(chainCopy.rules, ruleCopy) + chainCopy.handleToRule[ruleCopy.handle] = ruleCopy + } + return chainCopy +} + +// deepCopyTable returns a deep copy of the Table struct. +func deepCopyTable(table *Table, afFilter *addressFamilyFilter) *Table { + tableCopy := &Table{ + name: table.name, + afFilter: afFilter, + chains: make(map[string]*Chain), + chainHandles: make(map[uint64]*Chain), + flagSet: make(map[TableFlag]struct{}), + handle: table.handle, + owner: table.owner, + userData: slices.Clone(table.userData), + } + tableCopy.handleCounter.Store(table.handleCounter.Load()) + + for flag := range table.flagSet { + tableCopy.flagSet[flag] = struct{}{} + } + + for chainName, chain := range table.chains { + chainCopy := deepCopyChain(chain, tableCopy) + tableCopy.chains[chainName] = chainCopy + tableCopy.chainHandles[chainCopy.handle] = chainCopy + } + return tableCopy +} + +// DeepCopy returns a deep copy of the NFTables struct. +// Assumes that the caller has already locked the mutex. +// ********************************************************************** +// TODO: b/436922484: Add a transaction system to avoid deep copying the entire +// NFTables structure. +// ********************************************************************** +func (nf *NFTables) DeepCopy() *NFTables { + nftCopy := &NFTables{ + clock: nf.clock, + startTime: nf.startTime, + rng: nf.rng, + tableHandleCounter: atomicbitops.Uint64{}, + } + + nftCopy.tableHandleCounter.Store(nf.tableHandleCounter.Load()) + for i, filter := range nf.filters { + if filter == nil { + continue + } + + nftCopy.filters[i] = &addressFamilyFilter{ + family: filter.family, + nftState: nftCopy, + tables: make(map[string]*Table), + tableHandles: make(map[uint64]*Table), + hfStacks: make(map[stack.NFHook]*hookFunctionStack), + } + + for tableName, table := range filter.tables { + tableCopy := deepCopyTable(table, nftCopy.filters[i]) + nftCopy.filters[i].tables[tableName] = tableCopy + nftCopy.filters[i].tableHandles[tableCopy.handle] = tableCopy + } + + for hook, hfStack := range filter.hfStacks { + hfStackCopy := &hookFunctionStack{ + hook: hfStack.hook, + } + for _, chain := range hfStack.baseChains { + hfStackCopy.baseChains = append(hfStackCopy.baseChains, nftCopy.filters[i].tables[chain.table.name].chains[chain.name]) + } + nftCopy.filters[i].hfStacks[hook] = hfStackCopy + } + } + return nftCopy +} + +// ReplaceNFTables replaces the tables of the NFTables struct +// with the tables of the passed in NFTables struct. +func (nf *NFTables) ReplaceNFTables(nftCopy *NFTables) { + nf.filters = nftCopy.filters +}