diff --git a/mmap.go b/mmap.go index 6708492..7d6abd0 100644 --- a/mmap.go +++ b/mmap.go @@ -15,6 +15,7 @@ import ( var ( uint32Size = uint32(unsafe.Sizeof(uint32(0))) sqeSize = uint32(unsafe.Sizeof(iouring_syscall.SubmissionQueueEntry{})) + cqeSize = uint32(unsafe.Sizeof(iouring_syscall.CompletionQueueEvent{})) ) func mmapIOURing(iour *IOURing) (err error) { @@ -26,7 +27,17 @@ func mmapIOURing(iour *IOURing) (err error) { iour.sq = new(SubmissionQueue) iour.cq = new(CompletionQueue) - if err = mmapSQ(iour); err != nil { + iour.sq.size = iour.params.SQOffset.Array + iour.params.SQEntries*uint32Size + iour.cq.size = iour.params.CQOffset.Cqes + iour.params.CQEntries*cqeSize + if (iour.params.Features & iouring_syscall.IORING_FEAT_SINGLE_MMAP) != 0 { + if iour.cq.size > iour.sq.size { + iour.sq.size = iour.cq.size + } else { + iour.cq.size = iour.sq.size + } + } + + if err = mmapSQ(iour.fd, iour.params, iour.sq); err != nil { return err } @@ -34,22 +45,18 @@ func mmapIOURing(iour *IOURing) (err error) { iour.cq.ptr = iour.sq.ptr } - if err = mmapCQ(iour); err != nil { + if err = mmapCQ(iour.fd, iour.params, iour.cq); err != nil { return err } - if err = mmapSQEs(iour); err != nil { + if err = mmapSQEs(iour.fd, iour.params, iour.sq); err != nil { return err } return nil } -func mmapSQ(iour *IOURing) (err error) { - sq := iour.sq - params := iour.params - - sq.size = params.SQOffset.Array + params.SQEntries*uint32Size - sq.ptr, err = mmap(iour.fd, sq.size, iouring_syscall.IORING_OFF_SQ_RING) +func mmapSQ(fd int, params *iouring_syscall.IOURingParams, sq *SubmissionQueue) (err error) { + sq.ptr, err = mmap(fd, sq.size, iouring_syscall.IORING_OFF_SQ_RING) if err != nil { return fmt.Errorf("mmap sq ring: %w", err) } @@ -66,17 +73,12 @@ func mmapSQ(iour *IOURing) (err error) { Len: int(params.SQEntries), Cap: int(params.SQEntries), })) - return nil } -func mmapCQ(iour *IOURing) (err error) { - params := iour.params - cq := iour.cq - - cq.size = params.CQOffset.Cqes + params.CQEntries*uint32Size +func mmapCQ(fd int, params *iouring_syscall.IOURingParams, cq *CompletionQueue) (err error) { if cq.ptr == 0 { - cq.ptr, err = mmap(iour.fd, cq.size, iouring_syscall.IORING_OFF_CQ_RING) + cq.ptr, err = mmap(fd, cq.size, iouring_syscall.IORING_OFF_CQ_RING) if err != nil { return fmt.Errorf("mmap cq ring: %w", err) } @@ -94,25 +96,21 @@ func mmapCQ(iour *IOURing) (err error) { Len: int(params.CQEntries), Cap: int(params.CQEntries), })) - return nil } -func mmapSQEs(iour *IOURing) error { - params := iour.params - - ptr, err := mmap(iour.fd, params.SQEntries*sqeSize, iouring_syscall.IORING_OFF_SQES) +func mmapSQEs(fd int, params *iouring_syscall.IOURingParams, sq *SubmissionQueue) (err error) { + ptr, err := mmap(fd, params.SQEntries*sqeSize, iouring_syscall.IORING_OFF_SQES) if err != nil { return fmt.Errorf("mmap sqe array: %w", err) } - iour.sq.sqes = *(*[]iouring_syscall.SubmissionQueueEntry)( + sq.sqes = *(*[]iouring_syscall.SubmissionQueueEntry)( unsafe.Pointer(&reflect.SliceHeader{ Data: ptr, Len: int(params.SQEntries), Cap: int(params.SQEntries), })) - return nil } @@ -141,7 +139,6 @@ func munmapIOURing(iour *IOURing) error { } iour.cq = nil } - return nil }