@@ -10,12 +10,29 @@ package nclient4
1010
1111import (
1212 "errors"
13+ "fmt"
1314 "io"
1415 "net"
1516
17+ "github.com/mdlayher/arp"
1618 "github.com/mdlayher/ethernet"
1719 "github.com/mdlayher/raw"
1820 "github.com/u-root/uio/uio"
21+ "github.com/vishvananda/netlink"
22+ )
23+
24+ // UDPConnType indicates the type of the udp conn.
25+ type UDPConnType int
26+
27+ const (
28+ // UDPBroadcast specifies the type of udp conn as broadcast.
29+ //
30+ // All the packets will be broadcasted.
31+ UDPBroadcast UDPConnType = 0
32+
33+ // UDPUnicast specifies the type of udp conn as unicast.
34+ // All the packets will be sent to a unicast MAC address.
35+ UDPUnicast UDPConnType = 1
1936)
2037
2138var (
@@ -28,13 +45,16 @@ var (
2845var (
2946 // ErrUDPAddrIsRequired is an error used when a passed argument is not of type "*net.UDPAddr".
3047 ErrUDPAddrIsRequired = errors .New ("must supply UDPAddr" )
48+
49+ // ErrHWAddrNotFound is an error used when getting MAC address failed.
50+ ErrHWAddrNotFound = errors .New ("hardware address not found" )
3151)
3252
33- // NewRawUDPConn returns a UDP connection bound to the interface and port
34- // given based on a raw packet socket. All packets are broadcasted.
53+ // NewRawUDPConn returns a UDP connection bound to the interface and udp addr
54+ // given based on a raw packet socket.
3555//
3656// The interface can be completely unconfigured.
37- func NewRawUDPConn (iface string , port int ) (net.PacketConn , error ) {
57+ func NewRawUDPConn (iface string , addr * net. UDPAddr , typ UDPConnType ) (net.PacketConn , error ) {
3858 ifc , err := net .InterfaceByName (iface )
3959 if err != nil {
4060 return nil , err
@@ -43,7 +63,12 @@ func NewRawUDPConn(iface string, port int) (net.PacketConn, error) {
4363 if err != nil {
4464 return nil , err
4565 }
46- return NewBroadcastUDPConn (rawConn , & net.UDPAddr {Port : port }), nil
66+
67+ if typ == UDPUnicast {
68+ return NewUnicastRawUDPConn (rawConn , addr ), nil
69+ }
70+
71+ return NewBroadcastUDPConn (rawConn , addr ), nil
4772}
4873
4974// BroadcastRawUDPConn uses a raw socket to send UDP packets to the broadcast
@@ -157,3 +182,76 @@ func (upc *BroadcastRawUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
157182 // Broadcasting is not always right, but hell, what the ARP do I know.
158183 return upc .PacketConn .WriteTo (packet , & raw.Addr {HardwareAddr : BroadcastMac })
159184}
185+
186+ // UnicastRawUDPConn inherits from BroadcastRawUDPConn and override the WriteTo method
187+ type UnicastRawUDPConn struct {
188+ * BroadcastRawUDPConn
189+ }
190+
191+ // NewUnicastRawUDPConn returns a PacketConn
192+ func NewUnicastRawUDPConn (rawPacketConn net.PacketConn , boundAddr * net.UDPAddr ) net.PacketConn {
193+ return & UnicastRawUDPConn {
194+ BroadcastRawUDPConn : NewBroadcastUDPConn (rawPacketConn , boundAddr ).(* BroadcastRawUDPConn ),
195+ }
196+ }
197+
198+ // WriteTo implements net.PacketConn.WriteTo.
199+ //
200+ // WriteTo try to get the MAC address of destination IP address before
201+ // unicast all packets at the raw socket level.
202+ func (upc * UnicastRawUDPConn ) WriteTo (b []byte , addr net.Addr ) (int , error ) {
203+ udpAddr , ok := addr .(* net.UDPAddr )
204+ if ! ok {
205+ return 0 , ErrUDPAddrIsRequired
206+ }
207+
208+ // Using the boundAddr is not quite right here, but it works.
209+ packet := udp4pkt (b , udpAddr , upc .boundAddr )
210+ dstMac , err := getHwAddr (udpAddr .IP )
211+ if err != nil {
212+ return 0 , ErrHWAddrNotFound
213+ }
214+
215+ return upc .PacketConn .WriteTo (packet , & raw.Addr {HardwareAddr : dstMac })
216+ }
217+
218+ // getHwAddr from local arp cache. If no existing, try to get it by arp protocol.
219+ func getHwAddr (ip net.IP ) (net.HardwareAddr , error ) {
220+ neighList , err := netlink .NeighListExecute (netlink.Ndmsg {
221+ Family : netlink .FAMILY_V4 ,
222+ State : netlink .NUD_REACHABLE ,
223+ })
224+ if err != nil {
225+ return nil , err
226+ }
227+
228+ for _ , neigh := range neighList {
229+ if ip .Equal (neigh .IP ) && neigh .HardwareAddr != nil {
230+ return neigh .HardwareAddr , nil
231+ }
232+ }
233+
234+ return arpResolve (ip )
235+ }
236+
237+ func arpResolve (dest net.IP ) (net.HardwareAddr , error ) {
238+ // auto match the interface based on routes
239+ routes , err := netlink .RouteGet (dest )
240+ if err != nil {
241+ return nil , err
242+ }
243+ if len (routes ) == 0 {
244+ return nil , fmt .Errorf ("no route to %s found" , dest .String ())
245+ }
246+ ifc , err := net .InterfaceByIndex (routes [0 ].LinkIndex )
247+ if err != nil {
248+ return nil , err
249+ }
250+
251+ c , err := arp .Dial (ifc )
252+ if err != nil {
253+ return nil , err
254+ }
255+
256+ return c .Resolve (dest )
257+ }
0 commit comments