aboutsummaryrefslogtreecommitdiff
path: root/drivers/firewire/fw-device-cdev.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/firewire/fw-device-cdev.c')
-rw-r--r--drivers/firewire/fw-device-cdev.c23
1 files changed, 19 insertions, 4 deletions
diff --git a/drivers/firewire/fw-device-cdev.c b/drivers/firewire/fw-device-cdev.c
index 6284375c639..1101ccd9b9c 100644
--- a/drivers/firewire/fw-device-cdev.c
+++ b/drivers/firewire/fw-device-cdev.c
@@ -406,8 +406,12 @@ static int ioctl_create_iso_context(struct client *client, void __user *arg)
if (copy_from_user(&request, arg, sizeof request))
return -EFAULT;
+ if (request.type > FW_ISO_CONTEXT_RECEIVE)
+ return -EINVAL;
+
client->iso_context = fw_iso_context_create(client->device->card,
- FW_ISO_CONTEXT_TRANSMIT,
+ request.type,
+ request.header_size,
iso_callback, client);
if (IS_ERR(client->iso_context))
return PTR_ERR(client->iso_context);
@@ -419,7 +423,7 @@ static int ioctl_queue_iso(struct client *client, void __user *arg)
{
struct fw_cdev_queue_iso request;
struct fw_cdev_iso_packet __user *p, *end, *next;
- unsigned long payload, payload_end;
+ unsigned long payload, payload_end, header_length;
int count;
struct {
struct fw_iso_packet packet;
@@ -456,12 +460,23 @@ static int ioctl_queue_iso(struct client *client, void __user *arg)
while (p < end) {
if (__copy_from_user(&u.packet, p, sizeof *p))
return -EFAULT;
+
+ if (client->iso_context->type == FW_ISO_CONTEXT_TRANSMIT) {
+ header_length = u.packet.header_length;
+ } else {
+ /* We require that header_length is a multiple of
+ * the fixed header size, ctx->header_size */
+ if (u.packet.header_length % client->iso_context->header_size != 0)
+ return -EINVAL;
+ header_length = 0;
+ }
+
next = (struct fw_cdev_iso_packet __user *)
- &p->header[u.packet.header_length / 4];
+ &p->header[header_length / 4];
if (next > end)
return -EINVAL;
if (__copy_from_user
- (u.packet.header, p->header, u.packet.header_length))
+ (u.packet.header, p->header, header_length))
return -EFAULT;
if (u.packet.skip &&
u.packet.header_length + u.packet.payload_length > 0)