Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@ type streamSession interface {
decrypt([]byte) error
}

type newStream struct {
readStream readStream
payloadType uint8
}

type session struct {
localContextMutex sync.Mutex
localContext, remoteContext *Context
localOptions, remoteOptions []ContextOption

newStream chan readStream
newStream chan newStream
acceptStreamTimeout time.Time

started chan interface{}
Expand Down
10 changes: 5 additions & 5 deletions session_srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //n
localOptions: localOpts,
remoteOptions: remoteOpts,
readStreams: map[uint32]readStream{},
newStream: make(chan readStream),
newStream: make(chan newStream),
acceptStreamTimeout: config.AcceptStreamTimeout,
started: make(chan interface{}),
closed: make(chan interface{}),
Expand Down Expand Up @@ -93,17 +93,17 @@ func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) {

// AcceptStream returns a stream to handle RTCP for a single SSRC
func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) {
stream, ok := <-s.newStream
newStream, ok := <-s.newStream
if !ok {
return nil, 0, errStreamAlreadyClosed
}

readStream, ok := stream.(*ReadStreamSRTCP)
readStream, ok := newStream.readStream.(*ReadStreamSRTCP)
if !ok {
return nil, 0, errFailedTypeAssertion
}

return readStream, stream.GetSSRC(), nil
return readStream, readStream.GetSSRC(), nil
}

// Close ends the session
Expand Down Expand Up @@ -172,7 +172,7 @@ func (s *SessionSRTCP) decrypt(buf []byte) error {
if !s.session.acceptStreamTimeout.IsZero() {
_ = s.session.nextConn.SetReadDeadline(time.Time{})
}
s.session.newStream <- r // Notify AcceptStream
s.session.newStream <- newStream{readStream: r} // Notify AcceptStream
}

readStream, ok := r.(*ReadStreamSRTCP)
Expand Down
24 changes: 16 additions & 8 deletions session_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nol
localOptions: localOpts,
remoteOptions: remoteOpts,
readStreams: map[uint32]readStream{},
newStream: make(chan readStream),
newStream: make(chan newStream),
acceptStreamTimeout: config.AcceptStreamTimeout,
started: make(chan interface{}),
closed: make(chan interface{}),
Expand Down Expand Up @@ -93,19 +93,26 @@ func (s *SessionSRTP) OpenReadStream(ssrc uint32) (*ReadStreamSRTP, error) {
return nil, errFailedTypeAssertion
}

// AcceptStream returns a stream to handle RTCP for a single SSRC
// AcceptStream returns a stream to handle RTP for a single SSRC
func (s *SessionSRTP) AcceptStream() (*ReadStreamSRTP, uint32, error) {
stream, ok := <-s.newStream
readStream, ssrc, _, err := s.AcceptStreamWithPayloadType()
return readStream, ssrc, err
}

// AcceptStreamWithPayloadType returns a stream to handle RTP for a single SSRC.
// It returns the payload type as well as the SSRC.
func (s *SessionSRTP) AcceptStreamWithPayloadType() (*ReadStreamSRTP, uint32, uint8, error) {
newStream, ok := <-s.newStream
if !ok {
return nil, 0, errStreamAlreadyClosed
return nil, 0, 0, errStreamAlreadyClosed
}

readStream, ok := stream.(*ReadStreamSRTP)
readStream, ok := newStream.readStream.(*ReadStreamSRTP)
if !ok {
return nil, 0, errFailedTypeAssertion
return nil, 0, 0, errFailedTypeAssertion
}

return readStream, stream.GetSSRC(), nil
return readStream, readStream.GetSSRC(), newStream.payloadType, nil
}

// Close ends the session
Expand Down Expand Up @@ -178,7 +185,8 @@ func (s *SessionSRTP) decrypt(buf []byte) error {
if !s.session.acceptStreamTimeout.IsZero() {
_ = s.session.nextConn.SetReadDeadline(time.Time{})
}
s.session.newStream <- r // Notify AcceptStream
// notify AcceptStream
s.session.newStream <- newStream{readStream: r, payloadType: h.PayloadType}
}

readStream, ok := r.(*ReadStreamSRTP)
Expand Down
40 changes: 40 additions & 0 deletions stream_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"io"
"sync"
"sync/atomic"
"time"

"github.com/pion/rtp"
Expand All @@ -27,6 +28,10 @@ type ReadStreamSRTP struct {
isInited bool

buffer io.ReadWriteCloser

peekedPacket []byte
peekedPacketMu sync.Mutex
peekedPacketPresent atomic.Bool
}

// Used by getOrCreateReadStream
Expand Down Expand Up @@ -74,8 +79,43 @@ func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) {
return n, err
}

// Peek reads the next full RTP packet from the nextConn, but queues it internally.
// The next call to Read (or the next call to Peek without a call to Read in between)
// will return the same packet again.
func (r *ReadStreamSRTP) Peek(buf []byte) (int, error) {
r.peekedPacketMu.Lock()
defer r.peekedPacketMu.Unlock()
if r.peekedPacketPresent.Load() {
return copy(buf, r.peekedPacket), nil
}
n, err := r.buffer.Read(buf)
if err != nil {
return n, err
}
if cap(r.peekedPacket) < n {
size := 1500
if size < n {
size = n
}
r.peekedPacket = make([]byte, size)
}
r.peekedPacket = r.peekedPacket[:n]
copy(r.peekedPacket, buf)
r.peekedPacketPresent.Store(true)
return n, nil
}

// Read reads and decrypts full RTP packet from the nextConn
func (r *ReadStreamSRTP) Read(buf []byte) (int, error) {
if r.peekedPacketPresent.Load() {
r.peekedPacketMu.Lock()
if r.peekedPacketPresent.Swap(false) {
n := copy(buf, r.peekedPacket)
r.peekedPacketMu.Unlock()
return n, nil
}
r.peekedPacketMu.Unlock()
}
return r.buffer.Read(buf)
}

Expand Down