Skip to content

Add support for virtIO Socket #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/c/c.zig
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const dtb = modsdf.dtb;
const sddf = modsdf.sddf;
const lionsos = modsdf.lionsos;
const Vmm = modsdf.Vmm;
const VmmVirtioSocketConnection = Vmm.VmmVirtioSocketConnection;
const SystemDescription = modsdf.sdf.SystemDescription;
const Pd = SystemDescription.ProtectionDomain;
const Irq = SystemDescription.Irq;
Expand Down Expand Up @@ -790,6 +791,24 @@ export fn sdfgen_vmm_add_virtio_mmio_net(c_vmm: *align(8) anyopaque, c_device: *
return true;
}

export fn sdfgen_vmm_virtio_socket_connection(c_sdf: *align(8) anyopaque, c_device: *align(8) anyopaque, c_vmm_a: *align(8) anyopaque, cid_a: u32, c_vmm_b: *align(8) anyopaque, cid_b: u32) *anyopaque {
const sdf: *SystemDescription = @ptrCast(c_sdf);
const device: *dtb.Node = @ptrCast(c_device);
const vmm_a: *Vmm = @ptrCast(c_vmm_a);
const vmm_b: *Vmm = @ptrCast(c_vmm_b);
const vsock = allocator.create(VmmVirtioSocketConnection) catch @panic("OOM");
vsock.* = VmmVirtioSocketConnection.init(allocator, sdf, device, vmm_a, cid_a, vmm_b, cid_b);

return vsock;
}

export fn sdfgen_vmm_virtio_socket_connection_connect(c_vmm_vsock: *align(8) anyopaque) bool {
const vsock: *VmmVirtioSocketConnection = @ptrCast(c_vmm_vsock);
vsock.connect() catch return false;

return true;
}

export fn sdfgen_vmm_connect(c_vmm: *align(8) anyopaque) bool {
const vmm: *Vmm = @ptrCast(c_vmm);
vmm.connect() catch return false;
Expand Down
11 changes: 10 additions & 1 deletion src/c/sdfgen.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,17 @@ bool sdfgen_sddf_gpu_serialise_config(void *system, char *output_dir);

/*** Virtual Machine Monitor ***/
void *sdfgen_vmm(void *sdf, void *vmm_pd, void *vm, char *name, void *dtb, bool one_to_one_ram);
bool sdfgen_vmm_add_passthrough_device(void *vmm, char *name, void *device);
bool sdfgen_vmm_add_passthrough_device(void *vmm, void *device);
bool sdfgen_vmm_add_passthrough_device_regions(void *vmm, void *device, void *regions, uint8_t num_regions);
bool sdfgen_vmm_add_passthrough_device_irqs(void *vmm, void *device, void *irqs, uint8_t num_irqs);
bool sdfgen_vmm_add_passthrough_irq(void *vmm, void *irq);
bool sdfgen_vmm_add_virtio_mmio_console(void *vmm, void *device, void *serial);
bool sdfgen_vmm_add_virtio_mmio_blk(void *vmm, void *device, void *blk, uint32_t partition);
bool sdfgen_vmm_add_virtio_mmio_net(void *vmm, void *device, void *net, void *copier, uint8_t mac_addr[6]);
void *sdfgen_vmm_virtio_socket_connection(void *sdf, void *device, void *vmm_a, uint32_t cid_a, void *vmm_b, uint32_t cid_b);
bool sdfgen_vmm_virtio_socket_connection_connect(void *vmm_vsock);
bool sdfgen_vmm_connect(void *vmm);
bool sdfgen_vmm_serialise_config(void *system, char *output_dir);

/*** LionsOS ***/

Expand Down
21 changes: 21 additions & 0 deletions src/python/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ class SddfStatus(IntEnum):
libsdfgen.sdfgen_vmm_add_virtio_mmio_blk.argtypes = [c_void_p, c_void_p, c_void_p, c_uint32]
libsdfgen.sdfgen_vmm_add_virtio_mmio_net.restype = c_bool
libsdfgen.sdfgen_vmm_add_virtio_mmio_net.argtypes = [c_void_p, c_void_p, c_void_p, c_void_p, c_char_p]
libsdfgen.sdfgen_vmm_virtio_socket_connection.restype = c_void_p
libsdfgen.sdfgen_vmm_virtio_socket_connection.argtypes = [c_void_p, c_void_p, c_void_p, c_uint32, c_void_p, c_uint32]
libsdfgen.sdfgen_vmm_virtio_socket_connection_connect.restype = c_bool
libsdfgen.sdfgen_vmm_virtio_socket_connection_connect.argtypes = [c_void_p]
libsdfgen.sdfgen_vmm_connect.restype = c_bool
libsdfgen.sdfgen_vmm_connect.argtypes = [c_void_p]
libsdfgen.sdfgen_vmm_serialise_config.restype = c_bool
Expand Down Expand Up @@ -1020,6 +1024,23 @@ def serialise_config(self, output_dir: str) -> bool:


class Vmm:
class VmmVirtioSocketConnection:
_obj: c_void_p

def __init__(
self,
sdf: SystemDescription,
device: DeviceTree.Node,
vmm_a: Vmm,
cid_a: int,
vmm_b: Vmm,
cid_b: int
):
self._obj = libsdfgen.sdfgen_vmm_virtio_socket_connection(sdf._obj, device._obj, vmm_a._obj, cid_a, vmm_b._obj, cid_b)

def connect(self) -> bool:
return libsdfgen.sdfgen_vmm_virtio_socket_connection_connect(self._obj)

_obj: c_void_p

def __init__(
Expand Down
200 changes: 181 additions & 19 deletions src/vmm.zig
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const Arch = SystemDescription.Arch;
const Mr = SystemDescription.MemoryRegion;
const Pd = SystemDescription.ProtectionDomain;
const Irq = SystemDescription.Irq;
const Channel = SystemDescription.Channel;
const Map = SystemDescription.Map;
const Vm = SystemDescription.VirtualMachine;

Expand Down Expand Up @@ -48,23 +49,163 @@ pub const LinuxUioRegion = extern struct {
const MAX_IRQS: usize = 32;
const MAX_VCPUS: usize = 32;
const MAX_LINUX_UIO_REGIONS: usize = 16;
const MAX_VIRTIO_MMIO_DEVICES: usize = 32;
const MAX_VIRTIO_MMIO_DEVICES_PER_TYPE: usize = 4;

pub const VmmVirtioSocketConnection = struct {
allocator: Allocator,
sdf: *SystemDescription,
device: *dtb.Node,
vmm_a: *Self,
cid_a: u32,
vmm_b: *Self,
cid_b: u32,
connected: bool,
rx_buf_size: u32 = 0x1000,

pub fn init(allocator: Allocator, sdf: *SystemDescription, device: *dtb.Node, vmm_a: *Self, cid_a: u32, vmm_b: *Self, cid_b: u32) VmmVirtioSocketConnection {
return .{
.allocator = allocator,
.sdf = sdf,
.device = device,
.vmm_a = vmm_a,
.cid_a = cid_a,
.vmm_b = vmm_b,
.cid_b = cid_b,
.connected = false,
};
}

const Data = extern struct {
const VirtioMmioDevice = extern struct {
pub const Type = enum(u8) {
net = 1,
blk = 2,
console = 3,
sound = 25,
fn is_cid_valid(cid: u32) bool {
switch (cid) {
0, 1, 2, 0xffffffff => {
return false;
},
else => {
return true;
},
}
}

fn does_cid_exists(vmm: *Self, cid: u32) bool {
var i: usize = 0;
while (i < vmm.data.num_virtio_mmio_socket_devices) {
if (vmm.data.virtio_mmio_socket_devices[i].cid == cid) {
return true;
}
i += 1;
}
return false;
}

pub fn connect(vsock_connection: *VmmVirtioSocketConnection) !void {
const allocator = vsock_connection.allocator;
const sdf = vsock_connection.sdf;
const vsock_mmio_device = vsock_connection.device;
const vmm_a: *Self = vsock_connection.vmm_a;
const cid_a: u32 = vsock_connection.cid_a;
const vmm_b: *Self = vsock_connection.vmm_b;
const cid_b: u32 = vsock_connection.cid_b;

if (!is_cid_valid(cid_a)) {
log.err("error connecting virtIO socket connection between VMM '{s}' and '{s}': invalid CID '{d}'", .{vmm_a.vmm.name, vmm_b.vmm.name, cid_a});
return error.InvalidCid;
}
if (!is_cid_valid(cid_b)) {
log.err("error connecting virtIO socket connection between VMM '{s}' and '{s}': invalid CID '{d}'", .{vmm_a.vmm.name, vmm_b.vmm.name, cid_b});
return error.InvalidCid;
}

if (does_cid_exists(vmm_a, cid_a)) {
log.err("error connecting virtIO socket connection between VMM '{s}' and '{s}': CID '{d}' already exists on VMM '{s}'", .{vmm_a.vmm.name, vmm_b.vmm.name, cid_a, vmm_a.vmm.name});
return error.DuplicateCid;
}
if (does_cid_exists(vmm_a, cid_b)) {
log.err("error connecting virtIO socket connection between VMM '{s}' and '{s}': CID '{d}' already exists on VMM '{s}'", .{vmm_a.vmm.name, vmm_b.vmm.name, cid_b, vmm_a.vmm.name});
return error.DuplicateCid;
}
if (does_cid_exists(vmm_b, cid_b)) {
log.err("error connecting virtIO socket connection between VMM '{s}' and '{s}': CID '{d}' already exists on VMM '{s}'", .{vmm_a.vmm.name, vmm_b.vmm.name, cid_b, vmm_b.vmm.name});
return error.DuplicateCid;
}
if (does_cid_exists(vmm_b, cid_a)) {
log.err("error connecting virtIO socket connection between VMM '{s}' and '{s}': CID '{d}' already exists on VMM '{s}'", .{vmm_a.vmm.name, vmm_b.vmm.name, cid_a, vmm_b.vmm.name});
return error.DuplicateCid;
}

const ch_vsock = try Channel.create(vmm_a.vmm, vmm_b.vmm, .{});
vsock_connection.sdf.addChannel(ch_vsock);

const mr_a_rx_buf = Mr.create(allocator, fmt(allocator, "vsock_{s}_rx_{s}_tx", .{vmm_a.vmm.name, vmm_b.vmm.name}), vsock_connection.rx_buf_size, .{});
const mr_b_rx_buf = Mr.create(allocator, fmt(allocator, "vsock_{s}_tx_{s}_rx", .{vmm_a.vmm.name, vmm_b.vmm.name}), vsock_connection.rx_buf_size, .{});
sdf.addMemoryRegion(mr_a_rx_buf);
sdf.addMemoryRegion(mr_b_rx_buf);

const map_a_rx_buf = Map.create(mr_a_rx_buf, vmm_a.vmm.getMapVaddr(&mr_a_rx_buf), .rw, .{});
vmm_a.vmm.addMap(map_a_rx_buf);
const map_a_tx_buf = Map.create(mr_b_rx_buf, vmm_a.vmm.getMapVaddr(&mr_b_rx_buf), .rw, .{});
vmm_a.vmm.addMap(map_a_tx_buf);

const map_b_rx_buf = Map.create(mr_b_rx_buf, vmm_b.vmm.getMapVaddr(&mr_b_rx_buf), .rw, .{});
vmm_b.vmm.addMap(map_b_rx_buf);
const map_b_tx_buf = Map.create(mr_a_rx_buf, vmm_b.vmm.getMapVaddr(&mr_a_rx_buf), .rw, .{});
vmm_b.vmm.addMap(map_b_tx_buf);

vmm_a.data.virtio_mmio_socket_devices[vmm_a.data.num_virtio_mmio_socket_devices] = .{
.regs = try vmm_a.parseVirtioMmioDeviceRegs(vsock_mmio_device),
.cid = cid_a,
.shared_buffer_size = vsock_connection.rx_buf_size,
.buffer_our = map_a_rx_buf.vaddr,
.buffer_peer = map_a_tx_buf.vaddr,
.peer_channel = ch_vsock.pd_a_id,
};
vmm_a.data.num_virtio_mmio_socket_devices += 1;

vmm_b.data.virtio_mmio_socket_devices[vmm_b.data.num_virtio_mmio_socket_devices] = .{
.regs = try vmm_b.parseVirtioMmioDeviceRegs(vsock_mmio_device),
.cid = cid_b,
.shared_buffer_size = vsock_connection.rx_buf_size,
.buffer_our = map_b_rx_buf.vaddr,
.buffer_peer = map_b_tx_buf.vaddr,
.peer_channel = ch_vsock.pd_b_id,
};
vmm_b.data.num_virtio_mmio_socket_devices += 1;

vsock_connection.connected = true;
}
};

type: u8,
const Data = extern struct {
const VirtioMmioDeviceRegs = extern struct {
addr: u64,
size: u32,
irq: u32,
};

const VirtioMmioConsoleDevice = extern struct {
regs: VirtioMmioDeviceRegs,
};

const VirtioMmioBlockDevice = extern struct {
regs: VirtioMmioDeviceRegs,
};

const VirtioMmioNetDevice = extern struct {
regs: VirtioMmioDeviceRegs,
};

const VirtioMmioSocketDevice = extern struct {
regs: VirtioMmioDeviceRegs,
cid: u32,
shared_buffer_size: u32,
buffer_our: u64,
buffer_peer: u64,
peer_channel: u32,
};

const VirtioMmioSoundDevice = extern struct {
regs: VirtioMmioDeviceRegs,
};

const Irq = extern struct {
id: u8,
irq: u32,
Expand All @@ -83,8 +224,22 @@ const Data = extern struct {
irqs: [MAX_IRQS]Data.Irq,
num_vcpus: u8,
vcpus: [MAX_VCPUS]Vcpu,
num_virtio_mmio_devices: u8,
virtio_mmio_devices: [MAX_VIRTIO_MMIO_DEVICES]VirtioMmioDevice,

num_virtio_mmio_console_devices: u8,
virtio_mmio_console_devices: [MAX_VIRTIO_MMIO_DEVICES_PER_TYPE]VirtioMmioConsoleDevice,

num_virtio_mmio_block_devices: u8,
virtio_mmio_block_devices: [MAX_VIRTIO_MMIO_DEVICES_PER_TYPE]VirtioMmioBlockDevice,

num_virtio_mmio_net_devices: u8,
virtio_mmio_net_devices: [MAX_VIRTIO_MMIO_DEVICES_PER_TYPE]VirtioMmioNetDevice,

num_virtio_mmio_socket_devices: u8,
virtio_mmio_socket_devices: [MAX_VIRTIO_MMIO_DEVICES_PER_TYPE]VirtioMmioSocketDevice,

num_virtio_mmio_sound_devices: u8,
virtio_mmio_sound_devices: [MAX_VIRTIO_MMIO_DEVICES_PER_TYPE]VirtioMmioSoundDevice,

num_linux_uio_regions: u8,
linux_uios: [MAX_LINUX_UIO_REGIONS]LinuxUioRegion,
};
Expand Down Expand Up @@ -220,7 +375,7 @@ pub fn addPassthroughDevice(system: *Self, device: *dtb.Node, options: Passthrou
}
}

fn addVirtioMmioDevice(system: *Self, device: *dtb.Node, t: Data.VirtioMmioDevice.Type) !void {
fn parseVirtioMmioDeviceRegs(system: *Self, device: *dtb.Node) !Data.VirtioMmioDeviceRegs {
const device_reg = device.prop(.Reg) orelse {
log.err("error adding virtIO device '{s}': missing 'reg' field on device node", .{device.name});
return error.InvalidVirtioDevice;
Expand All @@ -244,28 +399,35 @@ fn addVirtioMmioDevice(system: *Self, device: *dtb.Node, t: Data.VirtioMmioDevic

const irq = try dtb.parseIrq(system.sdf.arch, interrupts[0]);
// TODO: maybe use device resources like everything else? idk
system.data.virtio_mmio_devices[system.data.num_virtio_mmio_devices] = .{
.type = @intFromEnum(t),
return .{
.addr = device_paddr,
.size = @intCast(device_size),
.irq = irq.irq,
};
system.data.num_virtio_mmio_devices += 1;
}

pub fn addVirtioMmioConsole(system: *Self, device: *dtb.Node, serial: *sddf.Serial) !void {
try serial.addClient(system.vmm);
try system.addVirtioMmioDevice(device, .console);
system.data.virtio_mmio_console_devices[system.data.num_virtio_mmio_console_devices] = .{
.regs = try system.parseVirtioMmioDeviceRegs(device),
};
system.data.num_virtio_mmio_console_devices += 1;
}

pub fn addVirtioMmioBlk(system: *Self, device: *dtb.Node, blk: *sddf.Blk, options: sddf.Blk.ClientOptions) !void {
try blk.addClient(system.vmm, options);
try system.addVirtioMmioDevice(device, .blk);
system.data.virtio_mmio_block_devices[system.data.num_virtio_mmio_block_devices] = .{
.regs = try system.parseVirtioMmioDeviceRegs(device),
};
system.data.num_virtio_mmio_block_devices += 1;
}

pub fn addVirtioMmioNet(system: *Self, device: *dtb.Node, net: *sddf.Net, copier: *Pd, options: sddf.Net.ClientOptions) !void {
try net.addClientWithCopier(system.vmm, copier, options);
try system.addVirtioMmioDevice(device, .net);
system.data.virtio_mmio_net_devices[system.data.num_virtio_mmio_net_devices] = .{
.regs = try system.parseVirtioMmioDeviceRegs(device),
};
system.data.num_virtio_mmio_net_devices += 1;
}

pub fn addPassthroughIrq(system: *Self, irq: Irq) !void {
Expand Down Expand Up @@ -372,7 +534,7 @@ pub fn connect(system: *Self) !void {
if (sdf.arch.isArm()) {
const gic = dtb.ArmGic.fromDtb(sdf.arch, system.guest_dtb) orelse {
log.err("error connecting VMM '{s}' system: could not find GIC interrupt controller DTB node", .{vmm.name});
return error.MissinGicNode;
return error.MissingGicNode;
};
if (gic.hasMmioCpuInterface()) {
const gic_vcpu_mr_name = fmt(allocator, "{s}/vcpu", .{gic.node.name});
Expand Down