Skip to content

Commit 79f17a5

Browse files
amogh09xxx0624
andauthored
Prevent network-blackhole-port from affecting TMDS access (#4403)
* Protect TMDS IP from being affected by network-blackhole-port fault * Fix test --------- Co-authored-by: xingzhen <[email protected]>
1 parent f7dfa32 commit 79f17a5

File tree

6 files changed

+117
-35
lines changed

6 files changed

+117
-35
lines changed

agent/handlers/task_server_setup_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3806,6 +3806,8 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) {
38063806
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
38073807
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
38083808
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
3809+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
3810+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
38093811
)
38103812
}
38113813
tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody)

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go

Lines changed: 37 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
3131
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
3232
"github.com/aws/amazon-ecs-agent/ecs-agent/metrics"
33+
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds"
3334
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types"
3435
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils"
3536
v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4"
@@ -56,7 +57,7 @@ const (
5657
requestTimeoutSeconds = 5
5758
// Commands that will be used to start/stop/check fault.
5859
iptablesNewChainCmd = "iptables -w %d -N %s"
59-
iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s --dport %s -j DROP"
60+
iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s -d %s --dport %s -j %s"
6061
iptablesInsertChainCmd = "iptables -w %d -I %s -j %s"
6162
iptablesChainExistCmd = "iptables -w %d -C %s -p %s --dport %s -j DROP"
6263
iptablesClearChainCmd = "iptables -w %d -F %s"
@@ -71,6 +72,9 @@ const (
7172
tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 2 u32 match ip dst %s flowid 1:1"
7273
tcDeleteQdiscParentCommandString = "tc qdisc del dev %s parent 1:1 handle 10:"
7374
tcDeleteQdiscRootCommandString = "tc qdisc del dev %s root handle 1: prio"
75+
allIPv4CIDR = "0.0.0.0/0"
76+
dropTarget = "DROP"
77+
acceptTarget = "ACCEPT"
7478
)
7579

7680
type FaultHandler struct {
@@ -220,24 +224,42 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol,
220224
"taskArn": taskArn,
221225
})
222226

223-
// Appending a new rule based on the protocol and port number from the request body
224-
appendRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, requestTimeoutSeconds, chain, protocol, port)
225-
cmdOutput, err = h.runExecCommand(ctx, strings.Split(appendRuleCmdString, " "))
226-
if err != nil {
227-
logger.Error("Unable to append rule to chain", logger.Fields{
228-
"netns": netNs,
229-
"command": appendRuleCmdString,
227+
// Helper function to run iptables rule change commands
228+
var execRuleChangeCommand = func(cmdString string) (string, error) {
229+
// Appending a new rule based on the protocol and port number from the request body
230+
cmdOutput, err = h.runExecCommand(ctx, strings.Split(cmdString, " "))
231+
if err != nil {
232+
logger.Error("Unable to add rule to chain", logger.Fields{
233+
"netns": netNs,
234+
"command": cmdString,
235+
"output": string(cmdOutput),
236+
"taskArn": taskArn,
237+
"error": err,
238+
})
239+
return string(cmdOutput), err
240+
}
241+
logger.Info("Successfully added new rule to iptable chain", logger.Fields{
242+
"command": cmdString,
230243
"output": string(cmdOutput),
231244
"taskArn": taskArn,
232-
"error": err,
233245
})
234-
return string(cmdOutput), err
246+
return "", nil
247+
}
248+
249+
// Add a rule to accept all traffic to TMDS
250+
protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
251+
requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks,
252+
acceptTarget)
253+
if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil {
254+
return out, err
255+
}
256+
257+
// Add a rule to drop all traffic to the port that the fault targets
258+
faultRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
259+
requestTimeoutSeconds, chain, protocol, allIPv4CIDR, port, dropTarget)
260+
if out, err := execRuleChangeCommand(faultRuleCmdString); err != nil {
261+
return out, err
235262
}
236-
logger.Info("Successfully appended new rule to iptable chain", logger.Fields{
237-
"command": appendRuleCmdString,
238-
"output": string(cmdOutput),
239-
"taskArn": taskArn,
240-
})
241263

242264
// Inserting the chain into the built-in INPUT/OUTPUT table
243265
insertChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesInsertChainCmd, requestTimeoutSeconds, insertTable, chain)

ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
521521
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
522522
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
523523
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
524+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
525+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
524526
)
525527
},
526528
},
@@ -554,6 +556,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
554556
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
555557
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
556558
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
559+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
560+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
557561
)
558562
},
559563
},
@@ -578,7 +582,7 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
578582
},
579583
},
580584
{
581-
name: fmt.Sprintf("%s fail append rule to chain", startNetworkBlackHolePortTestPrefix),
585+
name: fmt.Sprintf("%s fail append ACCEPT rule to chain", startNetworkBlackHolePortTestPrefix),
582586
expectedStatusCode: 500,
583587
requestBody: happyBlackHolePortReqBody,
584588
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
@@ -603,6 +607,34 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
603607
)
604608
},
605609
},
610+
{
611+
name: fmt.Sprintf("%s fail append DROP rule to chain", startNetworkBlackHolePortTestPrefix),
612+
expectedStatusCode: 500,
613+
requestBody: happyBlackHolePortReqBody,
614+
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
615+
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
616+
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
617+
Return(happyTaskResponse, nil).
618+
Times(1)
619+
},
620+
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
621+
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
622+
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
623+
gomock.InOrder(
624+
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
625+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
626+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
627+
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
628+
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
629+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
630+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
631+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
632+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
633+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
634+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")),
635+
)
636+
},
637+
},
606638
{
607639
name: fmt.Sprintf("%s fail insert chain to table", startNetworkBlackHolePortTestPrefix),
608640
expectedStatusCode: 500,

ecs-agent/tmds/server.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@ import (
2929

3030
const (
3131
// TMDS IP and port
32-
IPv4 = "127.0.0.1"
33-
Port = 51679
32+
IPv4 = "127.0.0.1"
33+
Port = 51679
34+
IPForTasks = "169.254.170.2"
35+
PortForTasks = "80"
3436
)
3537

3638
// IPv4 address for TMDS

0 commit comments

Comments
 (0)