diff --git a/tcpip/transport/icmp/endpoint.go b/tcpip/transport/icmp/endpoint.go index d6bdda7d..6a8a777a 100644 --- a/tcpip/transport/icmp/endpoint.go +++ b/tcpip/transport/icmp/endpoint.go @@ -86,7 +86,7 @@ type endpoint struct { route stack.Route } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (*endpoint, *tcpip.Error) { return &endpoint{ stack: stack, netProto: netProto, @@ -365,7 +365,7 @@ func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { data = data[header.ICMPv4EchoMinimumSize:] // Linux performs these basic checks. - if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { + if (icmpv4.Type() != header.ICMPv4Echo && icmpv4.Type() != header.ICMPv4EchoReply) || icmpv4.Code() != 0 { return tcpip.ErrInvalidEndpointState } @@ -659,7 +659,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv switch e.netProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(vv.First()) - if h.Type() != header.ICMPv4EchoReply { + if h.Type() != header.ICMPv4EchoReply && h.Type() != header.ICMPv4Echo { e.stack.Stats().DroppedPackets.Increment() return } diff --git a/tcpip/transport/icmp/forwarder.go b/tcpip/transport/icmp/forwarder.go new file mode 100644 index 00000000..c16b6eba --- /dev/null +++ b/tcpip/transport/icmp/forwarder.go @@ -0,0 +1,97 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package icmp + +import ( + "github.com/google/netstack/tcpip" + "github.com/google/netstack/tcpip/buffer" + "github.com/google/netstack/tcpip/stack" + "github.com/google/netstack/waiter" +) + +// Forwarder is a session request forwarder, which allows clients to decide +// what to do with a session request, for example: ignore it, or process it. +// +// The canonical way of using it is to pass the Forwarder.HandlePacket function +// to stack.SetTransportProtocolHandler. +type Forwarder struct { + handler func(*ForwarderRequest) + + stack *stack.Stack +} + +// NewForwarder allocates and initializes a new forwarder. +func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { + return &Forwarder{ + stack: s, + handler: handler, + } +} + +// HandlePacket handles all packets. +// +// This function is expected to be passed as an argument to the +// stack.SetTransportProtocolHandler function. +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { + + f.handler(&ForwarderRequest{ + stack: f.stack, + route: r, + id: id, + vv: vv, + }) + + return true +} + +// ForwarderRequest represents a session request received by the forwarder and +// passed to the client. Clients may optionally create an endpoint to represent +// it via CreateEndpoint. +type ForwarderRequest struct { + stack *stack.Stack + route *stack.Route + id stack.TransportEndpointID + vv buffer.VectorisedView +} + +// ID returns the 4-tuple (src address, src port, dst address, dst port) that +// represents the session request. +func (r *ForwarderRequest) ID() stack.TransportEndpointID { + return r.id +} + +// CreateEndpoint creates a connected UDP endpoint for the session request. +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep, _ := newEndpoint(r.stack, r.route.NetProto, ProtocolNumber4, queue) + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber4, r.id, ep, true); err != nil { + ep.Close() + return nil, err + } + + ep.id = r.id + ep.route = r.route.Clone() + // ep.dstPort = r.id.RemotePort + ep.regNICID = r.route.NICID() + + ep.state = stateConnected + + ep.rcvMu.Lock() + ep.rcvReady = true + ep.rcvMu.Unlock() + + ep.HandlePacket(r.route, r.id, r.vv) + + return ep, nil +}