diff --git a/internal/protocols/webrtc/incoming_track.go b/internal/protocols/webrtc/incoming_track.go index 1ec8e436a5b..83e5fb5a4c8 100644 --- a/internal/protocols/webrtc/incoming_track.go +++ b/internal/protocols/webrtc/incoming_track.go @@ -18,6 +18,15 @@ const ( mimeTypeL16 = "audio/L16" ) +func incomingTrackTWCCExtensionID(params webrtc.RTPParameters) uint8 { + for _, ext := range params.HeaderExtensions { + if ext.URI == twccExtensionURI { + return uint8(ext.ID) + } + } + return 0 +} + var incomingVideoCodecs = []webrtc.RTPCodecParameters{ { RTPCodecCapability: webrtc.RTPCodecCapability{ @@ -244,12 +253,30 @@ type IncomingTrack struct { writeRTCP func([]rtcp.Packet) error log logger.Writer + twccExtID uint8 inboundRTPPacketsLost *counterdumper.Dumper rtpReceiver *rtpreceiver.Receiver } func (t *IncomingTrack) initialize() { t.OnPacketRTP = func(*rtp.Packet) {} + t.twccExtID = incomingTrackTWCCExtensionID(t.receiver.GetParameters()) +} + +func (t *IncomingTrack) stripTWCCExtension(pkt *rtp.Packet) { + if t.twccExtID == 0 || pkt.GetExtension(t.twccExtID) == nil { + return + } + + err := pkt.DelExtension(t.twccExtID) + if err != nil { + panic(err) + } + + if len(pkt.GetExtensionIDs()) == 0 { + pkt.Extension = false + pkt.ExtensionProfile = 0 + } } // Codec returns the track codec. @@ -358,6 +385,7 @@ func (t *IncomingTrack) start() { continue } + t.stripTWCCExtension(pkt) t.OnPacketRTP(pkt) } } diff --git a/internal/protocols/webrtc/peer_connection.go b/internal/protocols/webrtc/peer_connection.go index c6c4127061d..28aecd87b28 100644 --- a/internal/protocols/webrtc/peer_connection.go +++ b/internal/protocols/webrtc/peer_connection.go @@ -22,7 +22,8 @@ import ( ) const ( - webrtcStreamID = "mediamtx" + webrtcStreamID = "mediamtx" + twccExtensionURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" ) func interfaceIPs(interfaceList []string) ([]string, error) { diff --git a/internal/protocols/webrtc/peer_connection_test.go b/internal/protocols/webrtc/peer_connection_test.go index 59a47dd1383..f8266e882df 100644 --- a/internal/protocols/webrtc/peer_connection_test.go +++ b/internal/protocols/webrtc/peer_connection_test.go @@ -34,6 +34,15 @@ func gatherCodecs(tracks []*IncomingTrack) []webrtc.RTPCodecParameters { return codecs } +func senderHeaderExtensionID(params webrtc.RTPSendParameters, uri string) uint8 { + for _, ext := range params.HeaderExtensions { + if ext.URI == uri { + return uint8(ext.ID) + } + } + return 0 +} + func TestPeerConnectionCloseImmediately(t *testing.T) { pc := &PeerConnection{ LocalRandomUDP: true, @@ -500,15 +509,9 @@ func TestPeerConnectionReadSimulcast(t *testing.T) { err = reader.WaitUntilConnected(10 * time.Second) require.NoError(t, err) - var midExtID, ridExtID uint8 - for _, ext := range transceiver.Sender().GetParameters().HeaderExtensions { - switch ext.URI { - case sdp.SDESMidURI: - midExtID = uint8(ext.ID) - case sdp.SDESRTPStreamIDURI: - ridExtID = uint8(ext.ID) - } - } + params := transceiver.Sender().GetParameters() + midExtID := senderHeaderExtensionID(params, sdp.SDESMidURI) + ridExtID := senderHeaderExtensionID(params, sdp.SDESRTPStreamIDURI) require.NotZero(t, midExtID) require.NotZero(t, ridExtID) @@ -577,6 +580,96 @@ func TestPeerConnectionReadSimulcast(t *testing.T) { require.Equal(t, []string{"h", "l", "m"}, rids) } +func TestPeerConnectionStripIncomingTWCC(t *testing.T) { + pub, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + require.NoError(t, err) + defer pub.Close() //nolint:errcheck + + videoTrack, err := webrtc.NewTrackLocalStaticRTP( + webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeVP8, + ClockRate: 90000, + }, + "video", "publisher", + webrtc.WithRTPStreamID("l"), + ) + require.NoError(t, err) + + videoSender, err := pub.AddTrack(videoTrack) + require.NoError(t, err) + + reader := &PeerConnection{ + LocalRandomUDP: true, + IPsFromInterfaces: true, + Publish: false, + Log: test.NilLogger, + } + err = reader.Start() + require.NoError(t, err) + defer reader.Close() + + offer, err := pub.CreateOffer(nil) + require.NoError(t, err) + + err = pub.SetLocalDescription(offer) + require.NoError(t, err) + + answer, err := reader.CreateFullAnswer(&offer) + require.NoError(t, err) + + err = pub.SetRemoteDescription(*answer) + require.NoError(t, err) + + err = reader.WaitUntilConnected(10 * time.Second) + require.NoError(t, err) + + params := videoSender.GetParameters() + twccExtID := senderHeaderExtensionID(params, twccExtensionURI) + require.NotZero(t, twccExtID) + + go func() { + time.Sleep(200 * time.Millisecond) + + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: true, + PayloadType: 96, + SequenceNumber: 55421, + Timestamp: 45343, + SSRC: 124123, + }, + Payload: []byte{5, 2}, + } + + pkt.ExtensionProfile = 0xBEDE + require.NoError(t, pkt.SetExtension(twccExtID, []byte{0x12, 0x34})) + + err2 := videoTrack.WriteRTP(pkt) + if err2 != nil { + return + } + }() + + err = reader.GatherIncomingTracks(5 * time.Second) + require.NoError(t, err) + + tracks := reader.IncomingTracks() + require.Len(t, tracks, 1) + + done := make(chan struct{}) + + tracks[0].OnPacketRTP = func(p *rtp.Packet) { + require.False(t, p.Extension) + require.Empty(t, p.Extensions) + close(done) + } + + reader.StartReading() + + <-done +} + func TestPeerConnectionPublishRead(t *testing.T) { pc1 := &PeerConnection{ LocalRandomUDP: true,