Skip to content

Commit

Permalink
Implement SO_TIMESTAMP
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 195047018
  • Loading branch information
iangudger authored and shentubot committed May 2, 2018
1 parent 7e6b108 commit c5b543a
Show file tree
Hide file tree
Showing 32 changed files with 345 additions and 150 deletions.
4 changes: 2 additions & 2 deletions pkg/dhcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error
// DHCPOFFER
for {
var addr tcpip.FullAddress
v, err := epin.Read(&addr)
v, _, err := epin.Read(&addr)
if err == tcpip.ErrWouldBlock {
select {
case <-ch:
Expand Down Expand Up @@ -216,7 +216,7 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error
// DHCPACK
for {
var addr tcpip.FullAddress
v, err := epin.Read(&addr)
v, _, err := epin.Read(&addr)
if err == tcpip.ErrWouldBlock {
select {
case <-ch:
Expand Down
2 changes: 1 addition & 1 deletion pkg/dhcp/dhcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestDHCP(t *testing.T) {
}
}()

s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName})
s := stack.New(&tcpip.StdClock{}, []string{ipv4.ProtocolName}, []string{udp.ProtocolName})

const nicid tcpip.NICID = 1
if err := s.CreateNIC(nicid, id); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/dhcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (s *Server) reader(ctx context.Context) {

for {
var addr tcpip.FullAddress
v, err := s.ep.Read(&addr)
v, _, err := s.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
select {
case <-ch:
Expand Down
2 changes: 1 addition & 1 deletion pkg/sentry/fs/host/socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestSocketSendMsgLen0(t *testing.T) {
defer sfile.DecRef()

s := sfile.FileOperations.(socket.Socket)
n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, unix.ControlMessages{})
n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, socket.ControlMessages{})
if n != 0 {
t.Fatalf("socket sendmsg() failed: %v wrote: %d", terr, n)
}
Expand Down
9 changes: 9 additions & 0 deletions pkg/sentry/kernel/kernel.go
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,15 @@ func (k *Kernel) SetExitError(err error) {
}
}

// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
func (k *Kernel) NowNanoseconds() int64 {
now, err := k.timekeeper.GetTime(sentrytime.Realtime)
if err != nil {
panic("Kernel.NowNanoseconds: " + err.Error())
}
return now
}

// SupervisorContext returns a Context with maximum privileges in k. It should
// only be used by goroutines outside the control of the emulated kernel
// defined by e.
Expand Down
1 change: 1 addition & 0 deletions pkg/sentry/socket/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ go_library(
"//pkg/sentry/usermem",
"//pkg/state",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/transport/unix",
],
)
35 changes: 35 additions & 0 deletions pkg/sentry/socket/control/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,31 @@ func putCmsg(buf []byte, msgType uint32, align uint, data []int32) []byte {
return alignSlice(buf, align)
}

func putCmsgStruct(buf []byte, msgType uint32, align uint, data interface{}) []byte {
if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader {
return buf
}
ob := buf

buf = putUint64(buf, uint64(linux.SizeOfControlMessageHeader))
buf = putUint32(buf, linux.SOL_SOCKET)
buf = putUint32(buf, msgType)

hdrBuf := buf

buf = binary.Marshal(buf, usermem.ByteOrder, data)

// Check if we went over.
if cap(buf) != cap(ob) {
return hdrBuf
}

// Fix up length.
putUint64(ob, uint64(len(buf)-len(ob)))

return alignSlice(buf, align)
}

// Credentials implements SCMCredentials.Credentials.
func (c *scmCredentials) Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
// "When a process's user and group IDs are passed over a UNIX domain
Expand Down Expand Up @@ -261,6 +286,16 @@ func alignSlice(buf []byte, align uint) []byte {
return buf[:aligned]
}

// PackTimestamp packs a SO_TIMESTAMP socket control message.
func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SO_TIMESTAMP,
t.Arch().Width(),
linux.NsecToTimeval(timestamp),
)
}

// Parse parses a raw socket control message into portable objects.
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (unix.ControlMessages, error) {
var (
Expand Down
69 changes: 47 additions & 22 deletions pkg/sentry/socket/epsocket/epsocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ type SocketOperations struct {
// readMu protects access to readView, control, and sender.
readMu sync.Mutex `state:"nosave"`
readView buffer.View
readCM tcpip.ControlMessages
sender tcpip.FullAddress
}

Expand Down Expand Up @@ -210,12 +211,13 @@ func (s *SocketOperations) fetchReadView() *syserr.Error {
s.readView = nil
s.sender = tcpip.FullAddress{}

v, err := s.Endpoint.Read(&s.sender)
v, cms, err := s.Endpoint.Read(&s.sender)
if err != nil {
return syserr.TranslateNetstackError(err)
}

s.readView = v
s.readCM = cms

return nil
}
Expand All @@ -230,7 +232,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
if dst.NumBytes() == 0 {
return 0, nil
}
n, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
n, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
if err == syserr.ErrWouldBlock {
return int64(n), syserror.ErrWouldBlock
}
Expand Down Expand Up @@ -552,6 +554,18 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,
}

return linux.NsecToTimeval(s.RecvTimeout()), nil

case linux.SO_TIMESTAMP:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}

var v tcpip.TimestampOption
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}

return int32(v), nil
}

case syscall.SOL_TCP:
Expand Down Expand Up @@ -659,6 +673,14 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n
binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
s.SetRecvTimeout(v.ToNsecCapped())
return nil

case linux.SO_TIMESTAMP:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}

v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TimestampOption(v)))
}

case syscall.SOL_TCP:
Expand Down Expand Up @@ -823,7 +845,9 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
}

// nonBlockingRead issues a non-blocking read.
func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, interface{}, uint32, *syserr.Error) {
//
// TODO: Support timestamps for stream sockets.
func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
isPacket := s.isPacketBased()

// Fast path for regular reads from stream (e.g., TCP) endpoints. Note
Expand All @@ -839,29 +863,29 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
s.readMu.Unlock()
return n, nil, 0, err
return n, nil, 0, socket.ControlMessages{}, err
}

s.readMu.Lock()
defer s.readMu.Unlock()

if err := s.fetchReadView(); err != nil {
return 0, nil, 0, err
return 0, nil, 0, socket.ControlMessages{}, err
}

if !isPacket && peek && trunc {
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the
// amount that could be read.
var rql tcpip.ReceiveQueueSizeOption
if err := s.Endpoint.GetSockOpt(&rql); err != nil {
return 0, nil, 0, syserr.TranslateNetstackError(err)
return 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
available := len(s.readView) + int(rql)
bufLen := int(dst.NumBytes())
if available < bufLen {
return available, nil, 0, nil
return available, nil, 0, socket.ControlMessages{}, nil
}
return bufLen, nil, 0, nil
return bufLen, nil, 0, socket.ControlMessages{}, nil
}

n, err := dst.CopyOut(ctx, s.readView)
Expand All @@ -874,17 +898,18 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
if peek {
if l := len(s.readView); trunc && l > n {
// isPacket must be true.
return l, addr, addrLen, syserr.FromError(err)
return l, addr, addrLen, socket.ControlMessages{IP: s.readCM}, syserr.FromError(err)
}

if isPacket || err != nil {
return int(n), addr, addrLen, syserr.FromError(err)
return int(n), addr, addrLen, socket.ControlMessages{IP: s.readCM}, syserr.FromError(err)
}

// We need to peek beyond the first message.
dst = dst.DropFirst(n)
num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
n, err := s.Endpoint.Peek(dsts)
n, _, err := s.Endpoint.Peek(dsts)
// TODO: Handle peek timestamp.
if err != nil {
return int64(n), syserr.TranslateNetstackError(err).ToError()
}
Expand All @@ -895,7 +920,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
// We got some data, so no need to return an error.
err = nil
}
return int(n), nil, 0, syserr.FromError(err)
return int(n), nil, 0, socket.ControlMessages{IP: s.readCM}, syserr.FromError(err)
}

var msgLen int
Expand All @@ -908,23 +933,23 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
}

if trunc {
return msgLen, addr, addrLen, syserr.FromError(err)
return msgLen, addr, addrLen, socket.ControlMessages{IP: s.readCM}, syserr.FromError(err)
}

return int(n), addr, addrLen, syserr.FromError(err)
return int(n), addr, addrLen, socket.ControlMessages{IP: s.readCM}, syserr.FromError(err)
}

// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages unix.ControlMessages, err *syserr.Error) {
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0

peek := flags&linux.MSG_PEEK != 0
if senderRequested && !s.isPacketBased() {
// Stream sockets ignore the sender address.
senderRequested = false
}
n, senderAddr, senderAddrLen, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
if err != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
return
}
Expand All @@ -936,25 +961,25 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)

for {
n, senderAddr, senderAddrLen, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
if err != syserr.ErrWouldBlock {
return
}

if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if err == syserror.ETIMEDOUT {
return 0, nil, 0, unix.ControlMessages{}, syserr.ErrTryAgain
return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
return 0, nil, 0, unix.ControlMessages{}, syserr.FromError(err)
return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
}
}

// SendMsg implements the linux syscall sendmsg(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages unix.ControlMessages) (int, *syserr.Error) {
// Reject control messages.
if !controlMessages.Empty() {
func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Reject Unix control messages.
if !controlMessages.Unix.Empty() {
return 0, syserr.ErrInvalidArgument
}

Expand Down
10 changes: 6 additions & 4 deletions pkg/sentry/socket/hostinet/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ type socketOperations struct {
queue waiter.Queue
}

var _ = socket.Socket(&socketOperations{})

func newSocketFile(ctx context.Context, fd int, nonblock bool) (*fs.File, *syserr.Error) {
s := &socketOperations{fd: fd}
if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
Expand Down Expand Up @@ -339,14 +341,14 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}

// RecvMsg implements socket.Socket.RecvMsg.
func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, unix.ControlMessages, *syserr.Error) {
func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
// Whitelist flags.
//
// FIXME: We can't support MSG_ERRQUEUE because it uses ancillary
// messages that netstack/tcpip/transport/unix doesn't understand. Kill the
// Socket interface's dependence on netstack.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
return 0, nil, 0, unix.ControlMessages{}, syserr.ErrInvalidArgument
return 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
}

var senderAddr []byte
Expand Down Expand Up @@ -411,11 +413,11 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}

return int(n), senderAddr, uint32(len(senderAddr)), unix.ControlMessages{}, syserr.FromError(err)
return int(n), senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err)
}

// SendMsg implements socket.Socket.SendMsg.
func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages unix.ControlMessages) (int, *syserr.Error) {
func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Whitelist flags.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
return 0, syserr.ErrInvalidArgument
Expand Down
Loading

0 comments on commit c5b543a

Please sign in to comment.