===MODEL PATCH ===

--- a/drivers/net/wireguard/send.c
+++ b/drivers/net/wireguard/send.c
@@ -168,6 +168,12 @@ static bool encrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
 	unsigned int padding_len, plaintext_len, trailer_len;
 	struct scatterlist sg[MAX_SKB_FRAGS + 8];
 	struct message_data *header;
 	struct sk_buff *trailer;
 	int num_frags;
+
+	/* If the skb is cloned or its head is shared, then we can't write
+	 * into it, and thus can't encrypt it in-place.
+	 */
+	if (unlikely(skb_cloned(skb) || skb_head_is_locked(skb)))
+		return false;
 

===ORIGINAL PATCH ===
diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c
index 43db442b1373..cdc96968b0f4 100644
--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -258,6 +258,8 @@ static void wg_setup(struct net_device *dev)
 	enum { WG_NETDEV_FEATURES = NETIF_F_HW_CSUM | NETIF_F_RXCSUM |
 				    NETIF_F_SG | NETIF_F_GSO |
 				    NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA };
+	const int overhead = MESSAGE_MINIMUM_LENGTH + sizeof(struct udphdr) +
+			     max(sizeof(struct ipv6hdr), sizeof(struct iphdr));
 
 	dev->netdev_ops = &netdev_ops;
 	dev->hard_header_len = 0;
@@ -271,9 +273,8 @@ static void wg_setup(struct net_device *dev)
 	dev->features |= WG_NETDEV_FEATURES;
 	dev->hw_features |= WG_NETDEV_FEATURES;
 	dev->hw_enc_features |= WG_NETDEV_FEATURES;
-	dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH -
-		   sizeof(struct udphdr) -
-		   max(sizeof(struct ipv6hdr), sizeof(struct iphdr));
+	dev->mtu = ETH_DATA_LEN - overhead;
+	dev->max_mtu = round_down(INT_MAX, MESSAGE_PADDING_MULTIPLE) - overhead;
 
 	SET_NETDEV_DEVTYPE(dev, &device_type);
 
diff --git a/drivers/net/wireguard/send.c b/drivers/net/wireguard/send.c
index c13260563446..7348c10cbae3 100644
--- a/drivers/net/wireguard/send.c
+++ b/drivers/net/wireguard/send.c
@@ -143,16 +143,22 @@ static void keep_key_fresh(struct wg_peer *peer)
 
 static unsigned int calculate_skb_padding(struct sk_buff *skb)
 {
+	unsigned int padded_size, last_unit = skb->len;
+
+	if (unlikely(!PACKET_CB(skb)->mtu))
+		return ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE) - last_unit;
+
 	/* We do this modulo business with the MTU, just in case the networking
 	 * layer gives us a packet that's bigger than the MTU. In that case, we
 	 * wouldn't want the final subtraction to overflow in the case of the
-	 * padded_size being clamped.
+	 * padded_size being clamped. Fortunately, that's very rarely the case,
+	 * so we optimize for that not happening.
 	 */
-	unsigned int last_unit = skb->len % PACKET_CB(skb)->mtu;
-	unsigned int padded_size = ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE);
+	if (unlikely(last_unit > PACKET_CB(skb)->mtu))
+		last_unit %= PACKET_CB(skb)->mtu;
 
-	if (padded_size > PACKET_CB(skb)->mtu)
-		padded_size = PACKET_CB(skb)->mtu;
+	padded_size = min(PACKET_CB(skb)->mtu,
+			  ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE));
 	return padded_size - last_unit;
 }
 

