vhost_user: re-factor feature/proto bit checking
We can wrap up the feature checking into helpers to reduce the amount of boilerplate code while enabling us to use a more idiomatic ?; exit path on error. Signed-off-by: Alex Bennée <alex.bennee@linaro.org>
This commit is contained in:
parent
7772f02e1a
commit
102d14e3b5
3 changed files with 78 additions and 94 deletions
|
|
@ -35,6 +35,24 @@ impl DummySlaveReqHandler {
|
|||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to check if VirtioFeature enabled
|
||||
fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> {
|
||||
if self.acked_features & feat.bits() != 0 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::InvalidOperation)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to check is VhostUserProtocolFeatures enabled
|
||||
fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> {
|
||||
if self.acked_protocol_features & feat.bits() != 0 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::InvalidOperation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
|
||||
|
|
@ -190,9 +208,9 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
|
|||
fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
|
||||
// This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
|
||||
// has been negotiated.
|
||||
if self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
|
||||
return Err(Error::InvalidOperation);
|
||||
} else if index as usize >= self.queue_num || index as usize > self.queue_num {
|
||||
self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
|
||||
|
||||
if index as usize >= self.queue_num || index as usize > self.queue_num {
|
||||
return Err(Error::InvalidParam);
|
||||
}
|
||||
|
||||
|
|
@ -210,9 +228,9 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
|
|||
size: u32,
|
||||
_flags: VhostUserConfigFlags,
|
||||
) -> Result<Vec<u8>> {
|
||||
if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
|
||||
return Err(Error::InvalidOperation);
|
||||
} else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
|
||||
|
||||
if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
|
||||
|| size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
|
||||
|| size + offset > VHOST_USER_CONFIG_SIZE
|
||||
{
|
||||
|
|
@ -225,9 +243,9 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
|
|||
|
||||
fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> {
|
||||
let size = buf.len() as u32;
|
||||
if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
|
||||
return Err(Error::InvalidOperation);
|
||||
} else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
|
||||
|
||||
if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
|
||||
|| size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
|
||||
|| size + offset > VHOST_USER_CONFIG_SIZE
|
||||
{
|
||||
|
|
|
|||
|
|
@ -339,10 +339,7 @@ impl VhostBackend for Master {
|
|||
impl VhostUserMaster for Master {
|
||||
fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
|
||||
let mut node = self.node();
|
||||
let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
|
||||
if node.virtio_features & flag == 0 {
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
|
||||
let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?;
|
||||
let val = node.recv_reply::<VhostUserU64>(&hdr)?;
|
||||
node.protocol_features = val.value;
|
||||
|
|
@ -356,10 +353,7 @@ impl VhostUserMaster for Master {
|
|||
|
||||
fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
|
||||
let mut node = self.node();
|
||||
let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
|
||||
if node.virtio_features & flag == 0 {
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
|
||||
let val = VhostUserU64::new(features.bits());
|
||||
let hdr = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?;
|
||||
// Don't wait for ACK here because the protocol feature negotiation process hasn't been
|
||||
|
|
@ -371,9 +365,7 @@ impl VhostUserMaster for Master {
|
|||
|
||||
fn get_queue_num(&mut self) -> Result<u64> {
|
||||
let mut node = self.node();
|
||||
if !node.is_feature_mq_available() {
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_proto_feature(VhostUserProtocolFeatures::MQ)?;
|
||||
|
||||
let hdr = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?;
|
||||
let val = node.recv_reply::<VhostUserU64>(&hdr)?;
|
||||
|
|
@ -413,9 +405,7 @@ impl VhostUserMaster for Master {
|
|||
|
||||
let mut node = self.node();
|
||||
// depends on VhostUserProtocolFeatures::CONFIG
|
||||
if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
|
||||
|
||||
// vhost-user spec states that:
|
||||
// "Master payload: virtio device config space"
|
||||
|
|
@ -448,9 +438,7 @@ impl VhostUserMaster for Master {
|
|||
|
||||
let mut node = self.node();
|
||||
// depends on VhostUserProtocolFeatures::CONFIG
|
||||
if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
|
||||
|
||||
let hdr = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?;
|
||||
node.wait_for_ack(&hdr).map_err(|e| e.into())
|
||||
|
|
@ -458,9 +446,7 @@ impl VhostUserMaster for Master {
|
|||
|
||||
fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()> {
|
||||
let mut node = self.node();
|
||||
if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?;
|
||||
let fds = [fd.as_raw_fd()];
|
||||
let hdr = node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
|
||||
node.wait_for_ack(&hdr).map_err(|e| e.into())
|
||||
|
|
@ -471,9 +457,7 @@ impl VhostUserMaster for Master {
|
|||
inflight: &VhostUserInflight,
|
||||
) -> Result<(VhostUserInflight, File)> {
|
||||
let mut node = self.node();
|
||||
if node.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() == 0 {
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
|
||||
|
||||
let hdr = node.send_request_with_body(MasterReq::GET_INFLIGHT_FD, inflight, None)?;
|
||||
let (inflight, files) = node.recv_reply_with_files::<VhostUserInflight>(&hdr)?;
|
||||
|
|
@ -486,9 +470,7 @@ impl VhostUserMaster for Master {
|
|||
|
||||
fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()> {
|
||||
let mut node = self.node();
|
||||
if node.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() == 0 {
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
|
||||
|
||||
if inflight.mmap_size == 0 || inflight.num_queues == 0 || inflight.queue_size == 0 || fd < 0
|
||||
{
|
||||
|
|
@ -501,10 +483,7 @@ impl VhostUserMaster for Master {
|
|||
|
||||
fn get_max_mem_slots(&mut self) -> Result<u64> {
|
||||
let mut node = self.node();
|
||||
if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
|
||||
{
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
|
||||
|
||||
let hdr = node.send_request_header(MasterReq::GET_MAX_MEM_SLOTS, None)?;
|
||||
let val = node.recv_reply::<VhostUserU64>(&hdr)?;
|
||||
|
|
@ -514,10 +493,7 @@ impl VhostUserMaster for Master {
|
|||
|
||||
fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
|
||||
let mut node = self.node();
|
||||
if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
|
||||
{
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
|
||||
if region.memory_size == 0 || region.mmap_handle < 0 {
|
||||
return error_code(VhostUserError::InvalidParam);
|
||||
}
|
||||
|
|
@ -535,10 +511,7 @@ impl VhostUserMaster for Master {
|
|||
|
||||
fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
|
||||
let mut node = self.node();
|
||||
if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
|
||||
{
|
||||
return error_code(VhostUserError::InvalidOperation);
|
||||
}
|
||||
node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
|
||||
if region.memory_size == 0 {
|
||||
return error_code(VhostUserError::InvalidParam);
|
||||
}
|
||||
|
|
@ -753,8 +726,20 @@ impl MasterInternal {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn is_feature_mq_available(&self) -> bool {
|
||||
self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0
|
||||
fn check_feature(&self, feat: VhostUserVirtioFeatures) -> VhostUserResult<()> {
|
||||
if self.virtio_features & feat.bits() != 0 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(VhostUserError::InvalidOperation)
|
||||
}
|
||||
}
|
||||
|
||||
fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> VhostUserResult<()> {
|
||||
if self.acked_protocol_features & feat.bits() != 0 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(VhostUserError::InvalidOperation)
|
||||
}
|
||||
}
|
||||
|
||||
fn check_state(&self) -> VhostUserResult<()> {
|
||||
|
|
|
|||
|
|
@ -269,6 +269,22 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
|
|||
}
|
||||
}
|
||||
|
||||
fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> {
|
||||
if self.acked_virtio_features & feat.bits() != 0 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::InvalidOperation)
|
||||
}
|
||||
}
|
||||
|
||||
fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> {
|
||||
if self.acked_protocol_features & feat.bits() != 0 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::InvalidOperation)
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a vhost-user slave endpoint from a connected socket.
|
||||
pub fn from_stream(socket: UnixStream, backend: Arc<S>) -> Self {
|
||||
Self::new(Endpoint::from_stream(socket), backend)
|
||||
|
|
@ -416,9 +432,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
|
|||
self.send_ack_message(&hdr, res)?;
|
||||
}
|
||||
MasterReq::GET_QUEUE_NUM => {
|
||||
if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::MQ)?;
|
||||
self.check_request_size(&hdr, size, 0)?;
|
||||
let num = self.backend.get_queue_num()?;
|
||||
let msg = VhostUserU64::new(num);
|
||||
|
|
@ -426,11 +440,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
|
|||
}
|
||||
MasterReq::SET_VRING_ENABLE => {
|
||||
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
|
||||
if self.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits()
|
||||
== 0
|
||||
{
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
|
||||
let enable = match msg.num {
|
||||
1 => true,
|
||||
0 => false,
|
||||
|
|
@ -441,34 +451,24 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
|
|||
self.send_ack_message(&hdr, res)?;
|
||||
}
|
||||
MasterReq::GET_CONFIG => {
|
||||
if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
|
||||
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
|
||||
self.get_config(&hdr, &buf)?;
|
||||
}
|
||||
MasterReq::SET_CONFIG => {
|
||||
if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
|
||||
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
|
||||
let res = self.set_config(size, &buf);
|
||||
self.send_ack_message(&hdr, res)?;
|
||||
}
|
||||
MasterReq::SET_SLAVE_REQ_FD => {
|
||||
if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?;
|
||||
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
|
||||
let res = self.set_slave_req_fd(files);
|
||||
self.send_ack_message(&hdr, res)?;
|
||||
}
|
||||
MasterReq::GET_INFLIGHT_FD => {
|
||||
if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
|
||||
== 0
|
||||
{
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
|
||||
|
||||
let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
|
||||
let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
|
||||
|
|
@ -477,35 +477,21 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
|
|||
.send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?;
|
||||
}
|
||||
MasterReq::SET_INFLIGHT_FD => {
|
||||
if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
|
||||
== 0
|
||||
{
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
|
||||
let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
|
||||
let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
|
||||
let res = self.backend.set_inflight_fd(&msg, file);
|
||||
self.send_ack_message(&hdr, res)?;
|
||||
}
|
||||
MasterReq::GET_MAX_MEM_SLOTS => {
|
||||
if self.acked_protocol_features
|
||||
& VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
|
||||
== 0
|
||||
{
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
|
||||
self.check_request_size(&hdr, size, 0)?;
|
||||
let num = self.backend.get_max_mem_slots()?;
|
||||
let msg = VhostUserU64::new(num);
|
||||
self.send_reply_message(&hdr, &msg)?;
|
||||
}
|
||||
MasterReq::ADD_MEM_REG => {
|
||||
if self.acked_protocol_features
|
||||
& VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
|
||||
== 0
|
||||
{
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
|
||||
let mut files = files.ok_or(Error::InvalidParam)?;
|
||||
if files.len() != 1 {
|
||||
return Err(Error::InvalidParam);
|
||||
|
|
@ -516,12 +502,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
|
|||
self.send_ack_message(&hdr, res)?;
|
||||
}
|
||||
MasterReq::REM_MEM_REG => {
|
||||
if self.acked_protocol_features
|
||||
& VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
|
||||
== 0
|
||||
{
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
|
||||
|
||||
let msg =
|
||||
self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue