Skip to content

Commit 144ef0a

Browse files
committed
feat: Add RDMA device support with CDI integration
Implements RDMA capability detection and device exposure using the Mellanox rdmamap library. Adds discovery of RDMA devices (mlx5_*), protocol detection (RoCE/InfiniBand/iWARP), and automatic injection of character devices (/dev/infiniband/*) into containers via CDI. Key changes: - Add RdmaProvider interface with rdmamap integration - Expose RDMA character devices (uverbs, umad, issm, rdma_cm) - Set environment variables for device paths and names - Add comprehensive unit tests for RDMA functionality Matches feature parity with sriov-network-device-plugin for RDMA. Tested on Mellanox ConnectX hardware with RoCE. Signed-off-by: Fred Rolland <frolland@nvidia.com>
1 parent bae0300 commit 144ef0a

File tree

7 files changed

+495
-0
lines changed

7 files changed

+495
-0
lines changed

pkg/devicestate/state.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,16 @@ func (s *Manager) applyConfigOnDevice(ctx context.Context, ifNameIndex *int, cla
215215
})
216216
}
217217

218+
// If device is RDMA capable, add RDMA character devices
219+
if rdmaCapableAttr, ok := deviceInfo.Attributes[consts.AttributeRDMACapable]; ok && rdmaCapableAttr.BoolValue != nil && *rdmaCapableAttr.BoolValue {
220+
rdmaDeviceNodes, rdmaEnvs, err := s.handleRDMADevice(ctx, pciAddress, result.Device)
221+
if err != nil {
222+
return nil, fmt.Errorf("error handling RDMA device: %w", err)
223+
}
224+
deviceNodes = append(deviceNodes, rdmaDeviceNodes...)
225+
envs = append(envs, rdmaEnvs...)
226+
}
227+
218228
edits := &cdispec.ContainerEdits{
219229
Env: envs,
220230
DeviceNodes: deviceNodes,
@@ -254,6 +264,79 @@ func (s *Manager) applyConfigOnDevice(ctx context.Context, ifNameIndex *int, cla
254264
return preparedDevice, nil
255265
}
256266

267+
// handleRDMADevice handles RDMA device configuration and returns device nodes, environment variables, or an error
268+
func (s *Manager) handleRDMADevice(ctx context.Context, pciAddress, deviceName string) ([]*cdispec.DeviceNode, []string, error) {
269+
logger := klog.FromContext(ctx).WithName("handleRDMADevice")
270+
var deviceNodes []*cdispec.DeviceNode
271+
var envs []string
272+
273+
rdmaDevices, err := host.GetHelpers().GetRDMADeviceForPCI(pciAddress)
274+
if err != nil {
275+
logger.Error(err, "Failed to get RDMA devices for PCI address",
276+
"device", pciAddress)
277+
return nil, nil, err
278+
}
279+
280+
if len(rdmaDevices) == 0 {
281+
logger.V(2).Info("No RDMA devices found for PCI address", "device", pciAddress)
282+
return nil, nil, fmt.Errorf("no RDMA devices found for PCI address %s", pciAddress)
283+
}
284+
285+
logger.V(2).Info("Device is RDMA capable, adding RDMA character devices",
286+
"device", pciAddress, "rdmaDevices", rdmaDevices)
287+
288+
for _, rdmaDevice := range rdmaDevices {
289+
// Get character devices for this RDMA device
290+
charDevices, err := host.GetHelpers().GetRDMACharDevices(rdmaDevice)
291+
if err != nil {
292+
logger.Error(err, "Failed to get RDMA character devices, skipping",
293+
"device", pciAddress, "rdmaDevice", rdmaDevice)
294+
return nil, nil, err
295+
}
296+
297+
if len(charDevices) == 0 {
298+
logger.V(2).Info("No RDMA character devices found",
299+
"device", pciAddress, "rdmaDevice", rdmaDevice)
300+
return nil, nil, fmt.Errorf("no RDMA character devices found for RDMA device %s (PCI: %s)", rdmaDevice, pciAddress)
301+
}
302+
303+
// Use RDMA device name in env var key to support multiple RDMA devices
304+
rdmaDeviceSanitized := strings.ReplaceAll(rdmaDevice, "_", "")
305+
devicePrefix := strings.ReplaceAll(deviceName, "-", "_")
306+
307+
// Add each character device to the CDI spec
308+
for _, charDev := range charDevices {
309+
deviceNodes = append(deviceNodes, &cdispec.DeviceNode{
310+
Path: charDev,
311+
HostPath: charDev,
312+
Type: "c", // character device
313+
})
314+
315+
// Add environment variable for each character device type
316+
// Include RDMA device name to avoid collisions with multiple RDMA devices
317+
switch {
318+
case strings.Contains(charDev, "uverbs"):
319+
envs = append(envs, fmt.Sprintf("SRIOVNETWORK_%s_%s_RDMA_UVERBS=%s", devicePrefix, rdmaDeviceSanitized, charDev))
320+
case strings.Contains(charDev, "umad"):
321+
envs = append(envs, fmt.Sprintf("SRIOVNETWORK_%s_%s_RDMA_UMAD=%s", devicePrefix, rdmaDeviceSanitized, charDev))
322+
case strings.Contains(charDev, "issm"):
323+
envs = append(envs, fmt.Sprintf("SRIOVNETWORK_%s_%s_RDMA_ISSM=%s", devicePrefix, rdmaDeviceSanitized, charDev))
324+
case strings.Contains(charDev, "rdma_cm"):
325+
envs = append(envs, fmt.Sprintf("SRIOVNETWORK_%s_%s_RDMA_CM=%s", devicePrefix, rdmaDeviceSanitized, charDev))
326+
}
327+
}
328+
329+
logger.Info("Added RDMA character devices for device",
330+
"device", pciAddress, "rdmaDevice", rdmaDevice, "charDevices", charDevices, "envs", envs)
331+
332+
// Add RDMA device name to environment variables
333+
envs = append(envs, fmt.Sprintf("SRIOVNETWORK_%s_%s_RDMA_DEVICE=%s",
334+
devicePrefix, rdmaDeviceSanitized, rdmaDevice))
335+
}
336+
337+
return deviceNodes, envs, nil
338+
}
339+
257340
func (s *Manager) getNetAttachDefRawConfig(ctx context.Context, namespace string, netAttachDefName string) (string, error) {
258341
// Get the net attach def information
259342
netAttachDef := &netattdefv1.NetworkAttachmentDefinition{}

pkg/devicestate/state_test.go

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@ package devicestate
22

33
import (
44
"context"
5+
"fmt"
56

67
. "github.com/onsi/ginkgo/v2"
78
. "github.com/onsi/gomega"
9+
"go.uber.org/mock/gomock"
10+
"k8s.io/utils/ptr"
811

912
resourceapi "k8s.io/api/resource/v1"
1013

1114
"github.com/k8snetworkplumbingwg/dra-driver-sriov/pkg/consts"
15+
"github.com/k8snetworkplumbingwg/dra-driver-sriov/pkg/host"
16+
mock_host "github.com/k8snetworkplumbingwg/dra-driver-sriov/pkg/host/mock"
1217
)
1318

1419
var _ = Describe("Manager", func() {
@@ -45,4 +50,212 @@ var _ = Describe("Manager", func() {
4550
Expect(exists).To(BeFalse())
4651
})
4752
})
53+
54+
Context("RDMA Device Preparation", func() {
55+
It("should skip RDMA preparation when device is not RDMA capable", func() {
56+
// Create device without RDMA capability
57+
nonRdmaDevice := &resourceapi.Device{
58+
Name: "0000-08-00-1",
59+
Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{
60+
consts.AttributePciAddress: {
61+
StringValue: ptr.To("0000:08:00.1"),
62+
},
63+
consts.AttributeRDMACapable: {
64+
BoolValue: ptr.To(false),
65+
},
66+
},
67+
}
68+
69+
// Verify device is not RDMA capable
70+
rdmaCapable, exists := nonRdmaDevice.Attributes[consts.AttributeRDMACapable]
71+
Expect(exists).To(BeTrue())
72+
Expect(rdmaCapable.BoolValue).ToNot(BeNil())
73+
Expect(*rdmaCapable.BoolValue).To(BeFalse())
74+
75+
// Test the conditional logic that determines if RDMA preparation should occur
76+
// This replicates the production code condition:
77+
// if rdmaCapableAttr, ok := deviceInfo.Attributes[consts.AttributeRDMACapable]; ok && rdmaCapableAttr.BoolValue != nil && *rdmaCapableAttr.BoolValue
78+
shouldPrepareRDMA := exists && rdmaCapable.BoolValue != nil && *rdmaCapable.BoolValue
79+
Expect(shouldPrepareRDMA).To(BeFalse(), "RDMA preparation should be skipped for non-RDMA capable devices")
80+
81+
// When this condition is false, the production code never calls:
82+
// - GetRDMADeviceForPCI
83+
// - GetRDMACharDevices
84+
// This test verifies the condition evaluates correctly for non-RDMA devices
85+
})
86+
})
87+
88+
Context("handleRDMADevice", func() {
89+
var (
90+
mockCtrl *gomock.Controller
91+
mockHost *mock_host.MockInterface
92+
origHelpers host.Interface
93+
manager *Manager
94+
)
95+
96+
BeforeEach(func() {
97+
mockCtrl = gomock.NewController(GinkgoT())
98+
mockHost = mock_host.NewMockInterface(mockCtrl)
99+
// Save original helpers and replace with mock
100+
_ = host.GetHelpers()
101+
origHelpers = host.Helpers
102+
host.Helpers = mockHost
103+
104+
manager = &Manager{}
105+
})
106+
107+
AfterEach(func() {
108+
// Restore original helpers
109+
host.Helpers = origHelpers
110+
mockCtrl.Finish()
111+
})
112+
113+
It("should return device nodes and environment variables for RDMA device", func() {
114+
pciAddress := "0000:08:00.1"
115+
deviceName := "device-1"
116+
rdmaDeviceName := "mlx5_0"
117+
118+
// Mock GetRDMADeviceForPCI to return one RDMA device
119+
mockHost.EXPECT().GetRDMADeviceForPCI(pciAddress).Return([]string{rdmaDeviceName}, nil)
120+
121+
// Mock GetRDMACharDevices to return various character devices
122+
mockHost.EXPECT().GetRDMACharDevices(rdmaDeviceName).Return([]string{
123+
"/dev/infiniband/uverbs0",
124+
"/dev/infiniband/umad0",
125+
"/dev/infiniband/issm0",
126+
"/dev/infiniband/rdma_cm",
127+
}, nil)
128+
129+
// Call the function
130+
deviceNodes, envs, err := manager.handleRDMADevice(context.Background(), pciAddress, deviceName)
131+
132+
// Verify no error
133+
Expect(err).ToNot(HaveOccurred())
134+
135+
// Verify device nodes
136+
Expect(deviceNodes).To(HaveLen(4))
137+
Expect(deviceNodes[0].Path).To(Equal("/dev/infiniband/uverbs0"))
138+
Expect(deviceNodes[0].HostPath).To(Equal("/dev/infiniband/uverbs0"))
139+
Expect(deviceNodes[0].Type).To(Equal("c"))
140+
Expect(deviceNodes[1].Path).To(Equal("/dev/infiniband/umad0"))
141+
Expect(deviceNodes[2].Path).To(Equal("/dev/infiniband/issm0"))
142+
Expect(deviceNodes[3].Path).To(Equal("/dev/infiniband/rdma_cm"))
143+
144+
// Verify environment variables
145+
Expect(envs).To(HaveLen(5))
146+
Expect(envs).To(ContainElement("SRIOVNETWORK_device_1_mlx50_RDMA_UVERBS=/dev/infiniband/uverbs0"))
147+
Expect(envs).To(ContainElement("SRIOVNETWORK_device_1_mlx50_RDMA_UMAD=/dev/infiniband/umad0"))
148+
Expect(envs).To(ContainElement("SRIOVNETWORK_device_1_mlx50_RDMA_ISSM=/dev/infiniband/issm0"))
149+
Expect(envs).To(ContainElement("SRIOVNETWORK_device_1_mlx50_RDMA_CM=/dev/infiniband/rdma_cm"))
150+
Expect(envs).To(ContainElement("SRIOVNETWORK_device_1_mlx50_RDMA_DEVICE=mlx5_0"))
151+
})
152+
153+
It("should handle multiple RDMA devices", func() {
154+
pciAddress := "0000:08:00.1"
155+
deviceName := "device-1"
156+
157+
// Mock GetRDMADeviceForPCI to return two RDMA devices
158+
mockHost.EXPECT().GetRDMADeviceForPCI(pciAddress).Return([]string{"mlx5_0", "mlx5_1"}, nil)
159+
160+
// Mock GetRDMACharDevices for first device
161+
mockHost.EXPECT().GetRDMACharDevices("mlx5_0").Return([]string{"/dev/infiniband/uverbs0"}, nil)
162+
163+
// Mock GetRDMACharDevices for second device
164+
mockHost.EXPECT().GetRDMACharDevices("mlx5_1").Return([]string{"/dev/infiniband/uverbs1"}, nil)
165+
166+
// Call the function
167+
deviceNodes, envs, err := manager.handleRDMADevice(context.Background(), pciAddress, deviceName)
168+
169+
// Verify no error
170+
Expect(err).ToNot(HaveOccurred())
171+
172+
// Verify device nodes for both RDMA devices
173+
Expect(deviceNodes).To(HaveLen(2))
174+
Expect(deviceNodes[0].Path).To(Equal("/dev/infiniband/uverbs0"))
175+
Expect(deviceNodes[1].Path).To(Equal("/dev/infiniband/uverbs1"))
176+
177+
// Verify environment variables for both RDMA devices
178+
Expect(envs).To(HaveLen(4))
179+
Expect(envs).To(ContainElement("SRIOVNETWORK_device_1_mlx50_RDMA_UVERBS=/dev/infiniband/uverbs0"))
180+
Expect(envs).To(ContainElement("SRIOVNETWORK_device_1_mlx50_RDMA_DEVICE=mlx5_0"))
181+
Expect(envs).To(ContainElement("SRIOVNETWORK_device_1_mlx51_RDMA_UVERBS=/dev/infiniband/uverbs1"))
182+
Expect(envs).To(ContainElement("SRIOVNETWORK_device_1_mlx51_RDMA_DEVICE=mlx5_1"))
183+
})
184+
185+
It("should return error when GetRDMADeviceForPCI fails", func() {
186+
pciAddress := "0000:08:00.1"
187+
deviceName := "device-1"
188+
189+
// Mock GetRDMADeviceForPCI to return an error
190+
mockHost.EXPECT().GetRDMADeviceForPCI(pciAddress).Return(nil, fmt.Errorf("failed to get RDMA devices"))
191+
192+
// Call the function
193+
deviceNodes, envs, err := manager.handleRDMADevice(context.Background(), pciAddress, deviceName)
194+
195+
// Verify error is returned
196+
Expect(err).To(HaveOccurred())
197+
Expect(err.Error()).To(ContainSubstring("failed to get RDMA devices"))
198+
Expect(deviceNodes).To(BeNil())
199+
Expect(envs).To(BeNil())
200+
})
201+
202+
It("should return error when no RDMA devices found", func() {
203+
pciAddress := "0000:08:00.1"
204+
deviceName := "device-1"
205+
206+
// Mock GetRDMADeviceForPCI to return empty list
207+
mockHost.EXPECT().GetRDMADeviceForPCI(pciAddress).Return([]string{}, nil)
208+
209+
// Call the function
210+
deviceNodes, envs, err := manager.handleRDMADevice(context.Background(), pciAddress, deviceName)
211+
212+
// Verify error is returned
213+
Expect(err).To(HaveOccurred())
214+
Expect(err.Error()).To(ContainSubstring("no RDMA devices found"))
215+
Expect(deviceNodes).To(BeNil())
216+
Expect(envs).To(BeNil())
217+
})
218+
219+
It("should return error when GetRDMACharDevices fails", func() {
220+
pciAddress := "0000:08:00.1"
221+
deviceName := "device-1"
222+
rdmaDeviceName := "mlx5_0"
223+
224+
// Mock GetRDMADeviceForPCI to return one RDMA device
225+
mockHost.EXPECT().GetRDMADeviceForPCI(pciAddress).Return([]string{rdmaDeviceName}, nil)
226+
227+
// Mock GetRDMACharDevices to return an error
228+
mockHost.EXPECT().GetRDMACharDevices(rdmaDeviceName).Return(nil, fmt.Errorf("failed to get char devices"))
229+
230+
// Call the function
231+
deviceNodes, envs, err := manager.handleRDMADevice(context.Background(), pciAddress, deviceName)
232+
233+
// Verify error is returned
234+
Expect(err).To(HaveOccurred())
235+
Expect(err.Error()).To(ContainSubstring("failed to get char devices"))
236+
Expect(deviceNodes).To(BeNil())
237+
Expect(envs).To(BeNil())
238+
})
239+
240+
It("should return error when no character devices found", func() {
241+
pciAddress := "0000:08:00.1"
242+
deviceName := "device-1"
243+
rdmaDeviceName := "mlx5_0"
244+
245+
// Mock GetRDMADeviceForPCI to return one RDMA device
246+
mockHost.EXPECT().GetRDMADeviceForPCI(pciAddress).Return([]string{rdmaDeviceName}, nil)
247+
248+
// Mock GetRDMACharDevices to return empty list
249+
mockHost.EXPECT().GetRDMACharDevices(rdmaDeviceName).Return([]string{}, nil)
250+
251+
// Call the function
252+
deviceNodes, envs, err := manager.handleRDMADevice(context.Background(), pciAddress, deviceName)
253+
254+
// Verify error is returned
255+
Expect(err).To(HaveOccurred())
256+
Expect(err.Error()).To(ContainSubstring("no RDMA character devices found"))
257+
Expect(deviceNodes).To(BeNil())
258+
Expect(envs).To(BeNil())
259+
})
260+
})
48261
})

pkg/host/host.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ type Interface interface {
114114
// RDMA device functions
115115
GetRDMADeviceForPCI(pciAddr string) ([]string, error)
116116
VerifyRDMACapability(pciAddr string) (bool, error)
117+
GetRDMACharDevices(rdmaDeviceName string) ([]string, error)
117118
}
118119

119120
// Host provides unified host system functionality for SR-IOV, PCI operations, and driver management
@@ -783,6 +784,21 @@ func (h *Host) EnsureVhostModulesLoaded() error {
783784
// GetRDMADeviceForPCI returns the RDMA device names associated with a PCI address
784785
// Uses the rdmamap library from Mellanox for RDMA device detection
785786
func (h *Host) GetRDMADeviceForPCI(pciAddr string) ([]string, error) {
787+
// Validate input
788+
if pciAddr == "" {
789+
return nil, fmt.Errorf("pciAddr cannot be empty")
790+
}
791+
// Validate PCI address format (should be DDDD:BB:DD.F)
792+
parts := strings.Split(pciAddr, ":")
793+
if len(parts) != 3 {
794+
return nil, fmt.Errorf("invalid PCI address format: %s (expected format: DDDD:BB:DD.F)", pciAddr)
795+
}
796+
// Validate the last part contains device.function format
797+
deviceFunc := strings.Split(parts[2], ".")
798+
if len(deviceFunc) != 2 {
799+
return nil, fmt.Errorf("invalid PCI address format: %s (expected format: DDDD:BB:DD.F)", pciAddr)
800+
}
801+
786802
h.log.V(2).Info("GetRDMADeviceForPCI(): getting RDMA devices for PCI address", "device", pciAddr)
787803

788804
// Use rdmaProvider to get RDMA devices for this PCI address
@@ -812,3 +828,27 @@ func (h *Host) VerifyRDMACapability(pciAddr string) (bool, error) {
812828

813829
return hasRDMA, nil
814830
}
831+
832+
// GetRDMACharDevices returns the character device paths for an RDMA device
833+
// These are the actual device nodes (e.g., /dev/infiniband/uverbs0) that need to be
834+
// exposed to containers for RDMA functionality
835+
func (h *Host) GetRDMACharDevices(rdmaDeviceName string) ([]string, error) {
836+
// Validate input
837+
if rdmaDeviceName == "" {
838+
return nil, fmt.Errorf("rdmaDeviceName cannot be empty")
839+
}
840+
841+
h.log.Info("GetRDMACharDevices(): getting character devices for RDMA device", "rdmaDevice", rdmaDeviceName)
842+
843+
// Use rdmaProvider to get character devices for this RDMA device
844+
charDevices := h.rdmaProvider.GetRdmaCharDevices(rdmaDeviceName)
845+
846+
if len(charDevices) == 0 {
847+
h.log.Info("GetRDMACharDevices(): no character devices found", "rdmaDevice", rdmaDeviceName)
848+
return nil, nil
849+
}
850+
851+
h.log.Info("GetRDMACharDevices(): found character devices",
852+
"rdmaDevice", rdmaDeviceName, "charDevices", charDevices)
853+
return charDevices, nil
854+
}

0 commit comments

Comments
 (0)