Skip to content

Commit 54c02d4

Browse files
committed
feat: Add RDMA device support with CDI integration
Implements RDMA capability detection and device exposure using the Mellanox rdmamap library. Adds discovery of RDMA devices (mlx5_*), protocol detection (RoCE/InfiniBand/iWARP), and automatic injection of character devices (/dev/infiniband/*) into containers via CDI. Key changes: - Add RdmaProvider interface with rdmamap integration - Expose RDMA character devices (uverbs, umad, issm, rdma_cm) - Set environment variables for device paths and names - Add comprehensive unit tests for RDMA functionality Matches feature parity with sriov-network-device-plugin for RDMA. Tested on Mellanox ConnectX hardware with RoCE. Signed-off-by: Fred Rolland <frolland@nvidia.com>
1 parent efe2945 commit 54c02d4

File tree

7 files changed

+295
-0
lines changed

7 files changed

+295
-0
lines changed

pkg/devicestate/state.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,67 @@ func (s *Manager) applyConfigOnDevice(ctx context.Context, ifNameIndex *int, cla
214214
})
215215
}
216216

217+
// If device is RDMA capable, add RDMA character devices
218+
if rdmaCapableAttr, ok := deviceInfo.Attributes[consts.AttributeRDMACapable]; ok && rdmaCapableAttr.BoolValue != nil && *rdmaCapableAttr.BoolValue {
219+
rdmaDevices, err := host.GetHelpers().GetRDMADeviceForPCI(pciAddress)
220+
if err != nil {
221+
logger.Error(err, "Failed to get RDMA devices for PCI address",
222+
"device", pciAddress)
223+
} else if len(rdmaDevices) > 0 {
224+
logger.V(2).Info("Device is RDMA capable, adding RDMA character devices",
225+
"device", pciAddress, "rdmaDevices", rdmaDevices)
226+
227+
for _, rdmaDevice := range rdmaDevices {
228+
// Get character devices for this RDMA device
229+
charDevices, err := host.GetHelpers().GetRDMACharDevices(rdmaDevice)
230+
if err != nil {
231+
logger.Error(err, "Failed to get RDMA character devices, skipping",
232+
"device", pciAddress, "rdmaDevice", rdmaDevice)
233+
continue
234+
}
235+
236+
if len(charDevices) == 0 {
237+
logger.V(2).Info("No RDMA character devices found",
238+
"device", pciAddress, "rdmaDevice", rdmaDevice)
239+
continue
240+
}
241+
242+
// Use RDMA device name in env var key to support multiple RDMA devices
243+
rdmaDeviceSanitized := strings.ReplaceAll(rdmaDevice, "_", "")
244+
devicePrefix := strings.ReplaceAll(result.Device, "-", "_")
245+
246+
// Add each character device to the CDI spec
247+
for _, charDev := range charDevices {
248+
deviceNodes = append(deviceNodes, &cdispec.DeviceNode{
249+
Path: charDev,
250+
HostPath: charDev,
251+
Type: "c", // character device
252+
})
253+
254+
// Add environment variable for each character device type
255+
// Include RDMA device name to avoid collisions with multiple RDMA devices
256+
switch {
257+
case strings.Contains(charDev, "uverbs"):
258+
envs = append(envs, fmt.Sprintf("SRIOVNETWORK_%s_%s_RDMA_UVERBS=%s", devicePrefix, rdmaDeviceSanitized, charDev))
259+
case strings.Contains(charDev, "umad"):
260+
envs = append(envs, fmt.Sprintf("SRIOVNETWORK_%s_%s_RDMA_UMAD=%s", devicePrefix, rdmaDeviceSanitized, charDev))
261+
case strings.Contains(charDev, "issm"):
262+
envs = append(envs, fmt.Sprintf("SRIOVNETWORK_%s_%s_RDMA_ISSM=%s", devicePrefix, rdmaDeviceSanitized, charDev))
263+
case strings.Contains(charDev, "rdma_cm"):
264+
envs = append(envs, fmt.Sprintf("SRIOVNETWORK_%s_%s_RDMA_CM=%s", devicePrefix, rdmaDeviceSanitized, charDev))
265+
}
266+
}
267+
268+
logger.Info("Added RDMA character devices for device",
269+
"device", pciAddress, "rdmaDevice", rdmaDevice, "charDevices", charDevices)
270+
271+
// Add RDMA device name to environment variables
272+
envs = append(envs, fmt.Sprintf("SRIOVNETWORK_%s_%s_RDMA_DEVICE=%s",
273+
devicePrefix, rdmaDeviceSanitized, rdmaDevice))
274+
}
275+
}
276+
}
277+
217278
edits := &cdispec.ContainerEdits{
218279
Env: envs,
219280
DeviceNodes: deviceNodes,

pkg/devicestate/state_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55

66
. "github.com/onsi/ginkgo/v2"
77
. "github.com/onsi/gomega"
8+
"k8s.io/utils/ptr"
89

910
"github.com/k8snetworkplumbingwg/dra-driver-sriov/pkg/consts"
1011
resourceapi "k8s.io/api/resource/v1"
@@ -44,4 +45,38 @@ var _ = Describe("Manager", func() {
4445
Expect(exists).To(BeFalse())
4546
})
4647
})
48+
49+
Context("RDMA Device Preparation", func() {
50+
It("should skip RDMA preparation when device is not RDMA capable", func() {
51+
// Create device without RDMA capability
52+
nonRdmaDevice := &resourceapi.Device{
53+
Name: "0000-08-00-1",
54+
Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{
55+
consts.AttributePciAddress: {
56+
StringValue: ptr.To("0000:08:00.1"),
57+
},
58+
consts.AttributeRDMACapable: {
59+
BoolValue: ptr.To(false),
60+
},
61+
},
62+
}
63+
64+
// Verify device is not RDMA capable
65+
rdmaCapable, exists := nonRdmaDevice.Attributes[consts.AttributeRDMACapable]
66+
Expect(exists).To(BeTrue())
67+
Expect(rdmaCapable.BoolValue).ToNot(BeNil())
68+
Expect(*rdmaCapable.BoolValue).To(BeFalse())
69+
70+
// Test the conditional logic that determines if RDMA preparation should occur
71+
// This replicates the production code condition:
72+
// if rdmaCapableAttr, ok := deviceInfo.Attributes[consts.AttributeRDMACapable]; ok && rdmaCapableAttr.BoolValue != nil && *rdmaCapableAttr.BoolValue
73+
shouldPrepareRDMA := exists && rdmaCapable.BoolValue != nil && *rdmaCapable.BoolValue
74+
Expect(shouldPrepareRDMA).To(BeFalse(), "RDMA preparation should be skipped for non-RDMA capable devices")
75+
76+
// When this condition is false, the production code never calls:
77+
// - GetRDMADeviceForPCI
78+
// - GetRDMACharDevices
79+
// This test verifies the condition evaluates correctly for non-RDMA devices
80+
})
81+
})
4782
})

pkg/host/host.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ type Interface interface {
114114
// RDMA device functions
115115
GetRDMADeviceForPCI(pciAddr string) ([]string, error)
116116
VerifyRDMACapability(pciAddr string) (bool, error)
117+
GetRDMACharDevices(rdmaDeviceName string) ([]string, error)
117118
}
118119

119120
// Host provides unified host system functionality for SR-IOV, PCI operations, and driver management
@@ -783,6 +784,21 @@ func (h *Host) EnsureVhostModulesLoaded() error {
783784
// GetRDMADeviceForPCI returns the RDMA device names associated with a PCI address
784785
// Uses the rdmamap library from Mellanox for RDMA device detection
785786
func (h *Host) GetRDMADeviceForPCI(pciAddr string) ([]string, error) {
787+
// Validate input
788+
if pciAddr == "" {
789+
return nil, fmt.Errorf("pciAddr cannot be empty")
790+
}
791+
// Validate PCI address format (should be DDDD:BB:DD.F)
792+
parts := strings.Split(pciAddr, ":")
793+
if len(parts) != 3 {
794+
return nil, fmt.Errorf("invalid PCI address format: %s (expected format: DDDD:BB:DD.F)", pciAddr)
795+
}
796+
// Validate the last part contains device.function format
797+
deviceFunc := strings.Split(parts[2], ".")
798+
if len(deviceFunc) != 2 {
799+
return nil, fmt.Errorf("invalid PCI address format: %s (expected format: DDDD:BB:DD.F)", pciAddr)
800+
}
801+
786802
h.log.V(2).Info("GetRDMADeviceForPCI(): getting RDMA devices for PCI address", "device", pciAddr)
787803

788804
// Use rdmaProvider to get RDMA devices for this PCI address
@@ -813,3 +829,27 @@ func (h *Host) VerifyRDMACapability(pciAddr string) (bool, error) {
813829

814830
return hasRDMA, nil
815831
}
832+
833+
// GetRDMACharDevices returns the character device paths for an RDMA device
834+
// These are the actual device nodes (e.g., /dev/infiniband/uverbs0) that need to be
835+
// exposed to containers for RDMA functionality
836+
func (h *Host) GetRDMACharDevices(rdmaDeviceName string) ([]string, error) {
837+
// Validate input
838+
if rdmaDeviceName == "" {
839+
return nil, fmt.Errorf("rdmaDeviceName cannot be empty")
840+
}
841+
842+
h.log.V(2).Info("GetRDMACharDevices(): getting character devices for RDMA device", "rdmaDevice", rdmaDeviceName)
843+
844+
// Use rdmaProvider to get character devices for this RDMA device
845+
charDevices := h.rdmaProvider.GetRdmaCharDevices(rdmaDeviceName)
846+
847+
if len(charDevices) == 0 {
848+
h.log.V(2).Info("GetRDMACharDevices(): no character devices found", "rdmaDevice", rdmaDeviceName)
849+
return nil, nil
850+
}
851+
852+
h.log.Info("GetRDMACharDevices(): found character devices",
853+
"rdmaDevice", rdmaDeviceName, "charDevices", charDevices)
854+
return charDevices, nil
855+
}

pkg/host/host_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,41 @@ vhost_net 32768 1 tun, Live 0xffffffffa0456000`),
624624
Expect(err).NotTo(HaveOccurred())
625625
Expect(devices).To(BeNil())
626626
})
627+
628+
It("should return error when PCI address is empty", func() {
629+
devices, err := hostImpl.GetRDMADeviceForPCI("")
630+
Expect(err).To(HaveOccurred())
631+
Expect(err.Error()).To(ContainSubstring("pciAddr cannot be empty"))
632+
Expect(devices).To(BeNil())
633+
})
634+
635+
It("should return error when PCI address format is invalid - no colons", func() {
636+
devices, err := hostImpl.GetRDMADeviceForPCI("invalid")
637+
Expect(err).To(HaveOccurred())
638+
Expect(err.Error()).To(ContainSubstring("invalid PCI address format"))
639+
Expect(devices).To(BeNil())
640+
})
641+
642+
It("should return error when PCI address format is invalid - too few parts", func() {
643+
devices, err := hostImpl.GetRDMADeviceForPCI("0000:08")
644+
Expect(err).To(HaveOccurred())
645+
Expect(err.Error()).To(ContainSubstring("invalid PCI address format"))
646+
Expect(devices).To(BeNil())
647+
})
648+
649+
It("should return error when PCI address format is invalid - missing device.function dot", func() {
650+
devices, err := hostImpl.GetRDMADeviceForPCI("0000:08:00")
651+
Expect(err).To(HaveOccurred())
652+
Expect(err.Error()).To(ContainSubstring("invalid PCI address format"))
653+
Expect(devices).To(BeNil())
654+
})
655+
656+
It("should return error when PCI address format is invalid - too many parts", func() {
657+
devices, err := hostImpl.GetRDMADeviceForPCI("0000:08:00:01.5")
658+
Expect(err).To(HaveOccurred())
659+
Expect(err.Error()).To(ContainSubstring("invalid PCI address format"))
660+
Expect(devices).To(BeNil())
661+
})
627662
})
628663

629664
Context("VerifyRDMACapability", func() {
@@ -660,5 +695,94 @@ vhost_net 32768 1 tun, Live 0xffffffffa0456000`),
660695
Expect(capable).To(BeFalse())
661696
})
662697
})
698+
699+
Context("GetRDMACharDevices", func() {
700+
It("should return character devices when they exist", func() {
701+
mockRdmaProvider.EXPECT().
702+
GetRdmaCharDevices("mlx5_1").
703+
Return([]string{
704+
"/dev/infiniband/uverbs0",
705+
"/dev/infiniband/umad0",
706+
"/dev/infiniband/issm0",
707+
"/dev/infiniband/rdma_cm",
708+
}).
709+
Times(1)
710+
711+
charDevices, err := hostImpl.GetRDMACharDevices("mlx5_1")
712+
Expect(err).NotTo(HaveOccurred())
713+
Expect(charDevices).To(HaveLen(4))
714+
Expect(charDevices).To(ContainElements(
715+
"/dev/infiniband/uverbs0",
716+
"/dev/infiniband/umad0",
717+
"/dev/infiniband/issm0",
718+
"/dev/infiniband/rdma_cm",
719+
))
720+
})
721+
722+
It("should return subset of character devices when only some exist", func() {
723+
mockRdmaProvider.EXPECT().
724+
GetRdmaCharDevices("mlx5_2").
725+
Return([]string{
726+
"/dev/infiniband/uverbs1",
727+
"/dev/infiniband/rdma_cm",
728+
}).
729+
Times(1)
730+
731+
charDevices, err := hostImpl.GetRDMACharDevices("mlx5_2")
732+
Expect(err).NotTo(HaveOccurred())
733+
Expect(charDevices).To(HaveLen(2))
734+
Expect(charDevices).To(ContainElements(
735+
"/dev/infiniband/uverbs1",
736+
"/dev/infiniband/rdma_cm",
737+
))
738+
})
739+
740+
It("should return nil when no character devices exist", func() {
741+
mockRdmaProvider.EXPECT().
742+
GetRdmaCharDevices("mlx5_3").
743+
Return([]string{}).
744+
Times(1)
745+
746+
charDevices, err := hostImpl.GetRDMACharDevices("mlx5_3")
747+
Expect(err).NotTo(HaveOccurred())
748+
Expect(charDevices).To(BeNil())
749+
})
750+
751+
It("should handle multiple RDMA devices with different character devices", func() {
752+
// First RDMA device
753+
mockRdmaProvider.EXPECT().
754+
GetRdmaCharDevices("mlx5_1").
755+
Return([]string{
756+
"/dev/infiniband/uverbs0",
757+
"/dev/infiniband/umad0",
758+
}).
759+
Times(1)
760+
761+
// Second RDMA device
762+
mockRdmaProvider.EXPECT().
763+
GetRdmaCharDevices("mlx5_2").
764+
Return([]string{
765+
"/dev/infiniband/uverbs1",
766+
"/dev/infiniband/umad1",
767+
}).
768+
Times(1)
769+
770+
charDevices1, err := hostImpl.GetRDMACharDevices("mlx5_1")
771+
Expect(err).NotTo(HaveOccurred())
772+
Expect(charDevices1).To(HaveLen(2))
773+
774+
charDevices2, err := hostImpl.GetRDMACharDevices("mlx5_2")
775+
Expect(err).NotTo(HaveOccurred())
776+
Expect(charDevices2).To(HaveLen(2))
777+
Expect(charDevices2).NotTo(Equal(charDevices1))
778+
})
779+
780+
It("should return error when RDMA device name is empty", func() {
781+
charDevices, err := hostImpl.GetRDMACharDevices("")
782+
Expect(err).To(HaveOccurred())
783+
Expect(err.Error()).To(ContainSubstring("rdmaDeviceName cannot be empty"))
784+
Expect(charDevices).To(BeNil())
785+
})
786+
})
663787
})
664788
})

pkg/host/mock/mock_host.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/host/mock/mock_rdma_provider.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/host/rdma_provider.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
//go:generate mockgen -destination mock/mock_rdma_provider.go -source rdma_provider.go
2727
type RdmaProvider interface {
2828
GetRdmaDevicesForPcidev(pciAddr string) []string
29+
GetRdmaCharDevices(rdmaDeviceName string) []string
2930
}
3031

3132
type defaultRdmaProvider struct{}
@@ -35,6 +36,11 @@ func (defaultRdmaProvider) GetRdmaDevicesForPcidev(pciAddr string) []string {
3536
return rdmamap.GetRdmaDevicesForPcidev(pciAddr)
3637
}
3738

39+
// GetRdmaCharDevices returns character device paths for an RDMA device
40+
func (defaultRdmaProvider) GetRdmaCharDevices(rdmaDeviceName string) []string {
41+
return rdmamap.GetRdmaCharDevices(rdmaDeviceName)
42+
}
43+
3844
// newRdmaProvider creates a new default RDMA provider
3945
func newRdmaProvider() RdmaProvider {
4046
return &defaultRdmaProvider{}

0 commit comments

Comments
 (0)