Skip to content

Commit

Permalink
refactor: add device controller and test cases
Browse files Browse the repository at this point in the history
Signed-off-by: Jack Yu <[email protected]>
  • Loading branch information
Yu-Jack committed May 27, 2024
1 parent d2ddd05 commit ae394f4
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 55 deletions.
3 changes: 2 additions & 1 deletion pkg/controller/usbdevice/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/rancher/wrangler/pkg/relatedresource"

"github.com/harvester/pcidevices/pkg/config"
"github.com/harvester/pcidevices/pkg/deviceplugins"
)

const (
Expand All @@ -20,7 +21,7 @@ func Register(ctx context.Context, management *config.FactoryManager) error {
virtClient := management.KubevirtFactory.Kubevirt().V1().KubeVirt()

handler := NewHandler(usbDeviceCtrl, usbDeviceClaimCtrl)
usbDeviceClaimController := NewClaimHandler(usbDeviceCtrl.Cache(), usbDeviceClaimCtrl, usbDeviceCtrl, virtClient)
usbDeviceClaimController := NewClaimHandler(usbDeviceCtrl.Cache(), usbDeviceClaimCtrl, usbDeviceCtrl, virtClient, deviceplugins.NewUSBDevicePlugin)

usbDeviceClaimCtrl.OnChange(ctx, "usbClaimClient-device-claim", usbDeviceClaimController.OnUSBDeviceClaimChanged)
usbDeviceClaimCtrl.OnRemove(ctx, "usbClaimClient-device-claim-remove", usbDeviceClaimController.OnRemove)
Expand Down
94 changes: 71 additions & 23 deletions pkg/controller/usbdevice/usbdevice_claim_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,43 @@ import (
ctlkubevirtv1 "github.com/harvester/pcidevices/pkg/generated/controllers/kubevirt.io/v1"
)

var (
discoverAllowedUSBDevices = deviceplugins.DiscoverAllowedUSBDevices
)

type ClaimHandler struct {
usbClaimClient ctldevicerv1beta1.USBDeviceClaimClient
usbClient ctldevicerv1beta1.USBDeviceClient
virtClient ctlkubevirtv1.KubeVirtClient
lock *sync.Mutex
usbDeviceCache ctldevicerv1beta1.USBDeviceCache
devicePlugin map[string]*deviceplugins.USBDevicePlugin
usbClaimClient ctldevicerv1beta1.USBDeviceClaimClient
usbClient ctldevicerv1beta1.USBDeviceClient
virtClient ctlkubevirtv1.KubeVirtClient
lock *sync.Mutex
usbDeviceCache ctldevicerv1beta1.USBDeviceCache
devicePlugin map[string]*deviceController
devicePluginConvertor devicePluginConvertor
}

func NewClaimHandler(usbDeviceCache ctldevicerv1beta1.USBDeviceCache, usbClaimClient ctldevicerv1beta1.USBDeviceClaimClient, usbClient ctldevicerv1beta1.USBDeviceClient, virtClient ctlkubevirtv1.KubeVirtClient) *ClaimHandler {
type deviceController struct {
device deviceplugins.USBDevicePluginInterface
stop chan struct{}
started bool
}

type devicePluginConvertor func(resourceName string, devices []*deviceplugins.PluginDevices) deviceplugins.USBDevicePluginInterface

func NewClaimHandler(
usbDeviceCache ctldevicerv1beta1.USBDeviceCache,
usbClaimClient ctldevicerv1beta1.USBDeviceClaimClient,
usbClient ctldevicerv1beta1.USBDeviceClient,
virtClient ctlkubevirtv1.KubeVirtClient,
devicePluginHelper devicePluginConvertor,
) *ClaimHandler {
return &ClaimHandler{
usbDeviceCache: usbDeviceCache,
usbClaimClient: usbClaimClient,
usbClient: usbClient,
virtClient: virtClient,
lock: &sync.Mutex{},
devicePlugin: map[string]*deviceplugins.USBDevicePlugin{},
usbDeviceCache: usbDeviceCache,
usbClaimClient: usbClaimClient,
usbClient: usbClient,
virtClient: virtClient,
lock: &sync.Mutex{},
devicePlugin: map[string]*deviceController{},
devicePluginConvertor: devicePluginHelper,
}
}

Expand Down Expand Up @@ -76,12 +96,15 @@ func (h *ClaimHandler) OnUSBDeviceClaimChanged(_ string, usbDeviceClaim *v1beta1
return usbDeviceClaim, nil
}

pluginDevices := deviceplugins.DiscoverAllowedUSBDevices(newVirt.Spec.Configuration.PermittedHostDevices.USB)
pluginDevices := discoverAllowedUSBDevices(newVirt.Spec.Configuration.PermittedHostDevices.USB)

if pluginDevice := h.findDevicePlugin(pluginDevices, usbDevice); pluginDevice != nil {
usbDevicePlugin := deviceplugins.NewUSBDevicePlugin(usbDevice.Status.ResourceName, []*deviceplugins.PluginDevices{pluginDevice})
h.devicePlugin[usbDeviceClaim.Name] = usbDevicePlugin
go h.startDevicePlugin(usbDevicePlugin)
usbDevicePlugin := h.devicePluginConvertor(usbDevice.Status.ResourceName, []*deviceplugins.PluginDevices{pluginDevice})
deviceHan := &deviceController{
device: usbDevicePlugin,
}
h.devicePlugin[usbDeviceClaim.Name] = deviceHan
h.startDevicePlugin(deviceHan)
}

usbDeviceCp := usbDevice.DeepCopy()
Expand All @@ -98,12 +121,37 @@ func (h *ClaimHandler) OnUSBDeviceClaimChanged(_ string, usbDeviceClaim *v1beta1
return h.usbClaimClient.UpdateStatus(usbDeviceClaimCp)
}

func (h *ClaimHandler) startDevicePlugin(usbDevicePlugin *deviceplugins.USBDevicePlugin) {
stop := make(chan struct{})
if err := usbDevicePlugin.Start(stop); err != nil {
func (h *ClaimHandler) startDevicePlugin(deviceHan *deviceController) {
if deviceHan.started {
return
}

deviceHan.stop = make(chan struct{})

go func() {
if err := deviceHan.device.Start(deviceHan.stop); err != nil {
logrus.Errorf("failed to start device plugin: %v", err)
}
<-deviceHan.stop
}()

deviceHan.started = true
}

func (h *ClaimHandler) stopDevicePlugin(deviceHan *deviceController) error {
if !deviceHan.started {
return nil
}

close(deviceHan.stop)
deviceHan.started = false

if err := deviceHan.device.StopDevicePlugin(); err != nil {
logrus.Errorf("failed to start device plugin: %v", err)
return err
}
<-stop

return nil
}

func (h *ClaimHandler) findDevicePlugin(pluginDevices map[string][]*deviceplugins.PluginDevices, usbDevice *v1beta1.USBDevice) *deviceplugins.PluginDevices {
Expand Down Expand Up @@ -173,8 +221,8 @@ func (h *ClaimHandler) OnRemove(_ string, claim *v1beta1.USBDeviceClaim) (*v1bet
}
}

if dp, ok := h.devicePlugin[claim.Name]; ok {
if err := dp.StopDevicePlugin(); err != nil {
if handler, ok := h.devicePlugin[claim.Name]; ok {
if err := h.stopDevicePlugin(handler); err != nil {
return claim, err
}

Expand Down
60 changes: 30 additions & 30 deletions pkg/controller/usbdevice/usbdevice_claim_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ import (
kubevirtv1 "kubevirt.io/api/core/v1"

"github.com/harvester/pcidevices/pkg/apis/devices.harvesterhci.io/v1beta1"
"github.com/harvester/pcidevices/pkg/deviceplugins"
"github.com/harvester/pcidevices/pkg/generated/clientset/versioned/fake"
"github.com/harvester/pcidevices/pkg/util/fakeclients"
)

type mockUSBDevicePlugin struct{}

func (m *mockUSBDevicePlugin) Start(stop <-chan struct{}) error { return nil }
func (m *mockUSBDevicePlugin) StopDevicePlugin() error { return nil }

var (
mockUsbDevice1 = &v1beta1.USBDevice{
ObjectMeta: metav1.ObjectMeta{
Expand Down Expand Up @@ -55,18 +61,39 @@ var (
NodeName: "test-node",
},
}
mockUSBDevicePluginHelper = func(resourceName string, pluginDevices []*deviceplugins.PluginDevices) deviceplugins.USBDevicePluginInterface {
return &mockUSBDevicePlugin{}
}
)

func Test_OnUSBDeviceClaimChanged(t *testing.T) {
client := fake.NewSimpleClientset(mockUsbDevice1, mockUsbDeviceClaim1, mockKubeVirt)
discoverAllowedUSBDevices = func(usbs []kubevirtv1.USBHostDevice) map[string][]*deviceplugins.PluginDevices {
m := map[string][]*deviceplugins.PluginDevices{}
m[mockUsbDevice1.Status.ResourceName] = []*deviceplugins.PluginDevices{
{
ID: "test",
Devices: []*deviceplugins.USBDevice{
{
Vendor: 2385,
Product: 5734,
DevicePath: "/dev/bus/usb/001/002",
},
},
},
}
return m
}

handler := NewClaimHandler(
fakeclients.USBDeviceCache(client.DevicesV1beta1().USBDevices),
fakeclients.USBDeviceClaimsClient(client.DevicesV1beta1().USBDeviceClaims),
fakeclients.USBDevicesClient(client.DevicesV1beta1().USBDevices),
fakeclients.KubeVirtClient(client.KubevirtV1().KubeVirts),
mockUSBDevicePluginHelper,
)

// Test claim created
_, err := handler.OnUSBDeviceClaimChanged("", mockUsbDeviceClaim1)
assert.NoError(t, err)

Expand All @@ -93,38 +120,11 @@ func Test_OnUSBDeviceClaimChanged(t *testing.T) {
usbDevice, err := client.DevicesV1beta1().USBDevices().Get(context.Background(), mockUsbDevice1.Name, metav1.GetOptions{})
assert.NoError(t, err)
assert.Equal(t, true, usbDevice.Status.Enabled)
}

func Test_OnUSBDeviceClaimRemove(t *testing.T) {
mockUsbDevice1.Status.Enabled = true
mockKubeVirt.Spec.Configuration = kubevirtv1.KubeVirtConfiguration{
PermittedHostDevices: &kubevirtv1.PermittedHostDevices{
USB: []kubevirtv1.USBHostDevice{
{
ResourceName: mockUsbDevice1.Status.ResourceName,
ExternalResourceProvider: true,
Selectors: []kubevirtv1.USBSelector{
{
Vendor: "0951",
Product: "1666",
},
},
},
},
},
}
client := fake.NewSimpleClientset(mockUsbDevice1, mockKubeVirt)

handler := NewClaimHandler(
fakeclients.USBDeviceCache(client.DevicesV1beta1().USBDevices),
fakeclients.USBDeviceClaimsClient(client.DevicesV1beta1().USBDeviceClaims),
fakeclients.USBDevicesClient(client.DevicesV1beta1().USBDevices),
fakeclients.KubeVirtClient(client.KubevirtV1().KubeVirts),
)

_, err := handler.OnRemove("", mockUsbDeviceClaim1)
// Test claim removed
_, err = handler.OnRemove("", mockUsbDeviceClaim1)
assert.NoError(t, err)
usbDevice, err := client.DevicesV1beta1().USBDevices().Get(context.Background(), mockUsbDevice1.Name, metav1.GetOptions{})
usbDevice, err = client.DevicesV1beta1().USBDevices().Get(context.Background(), mockUsbDevice1.Name, metav1.GetOptions{})
assert.NoError(t, err)
assert.Equal(t, false, usbDevice.Status.Enabled)
kubeVirt, err := client.KubevirtV1().KubeVirts(mockKubeVirt.Namespace).Get(context.Background(), mockKubeVirt.Name, metav1.GetOptions{})
Expand Down
7 changes: 6 additions & 1 deletion pkg/deviceplugins/usb_device_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ type USBDevice struct {
PCIAddress string
}

type USBDevicePluginInterface interface {
Start(stop <-chan struct{}) error
StopDevicePlugin() error
}

func (dev *USBDevice) GetID() string {
return fmt.Sprintf("%04x:%04x-%02d:%02d", dev.Vendor, dev.Product, dev.Bus, dev.DeviceNumber)
}
Expand Down Expand Up @@ -447,7 +452,7 @@ func (plugin *USBDevicePlugin) PreStartContainer(context.Context, *pluginapi.Pre
return &pluginapi.PreStartContainerResponse{}, nil
}

func NewUSBDevicePlugin(resourceName string, pluginDevices []*PluginDevices) *USBDevicePlugin {
func NewUSBDevicePlugin(resourceName string, pluginDevices []*PluginDevices) USBDevicePluginInterface {
s := strings.Split(resourceName, "/")
resourceID := s[0]
if len(s) > 1 {
Expand Down

0 comments on commit ae394f4

Please sign in to comment.