diff --git a/modules/net/quic/packet.c b/modules/net/quic/packet.c index d8528ff..b0f3f53 100644 --- a/modules/net/quic/packet.c +++ b/modules/net/quic/packet.c @@ -604,6 +604,7 @@ static int quic_packet_handshake_header_process(struct sock *sk, struct sk_buff struct quic_outqueue *outq = quic_outq(sk); u32 len = skb->len, version; struct quic_data token; + struct udphdr *uh; u64 length; quic_packet_reset(packet); @@ -626,7 +627,9 @@ static int quic_packet_handshake_header_process(struct sock *sk, struct sk_buff if (quic_packet_get_token(&token, &p, &len)) return -EINVAL; packet->level = QUIC_CRYPTO_INITIAL; - if (!quic_is_serv(sk) && token.len) { + uh = (struct udphdr *)(skb->head + cb->udph_offset); + if ((!quic_is_serv(sk) && token.len) || + (quic_is_serv(sk) && ntohs(uh->len) - sizeof(*uh) < QUIC_MIN_UDP_PAYLOAD)) { packet->errcode = QUIC_TRANSPORT_ERROR_PROTOCOL_VIOLATION; return -EINVAL; } @@ -1263,12 +1266,12 @@ static struct sk_buff *quic_packet_handshake_create(struct sock *sk) } len = packet->len; + hlen = QUIC_MIN_UDP_PAYLOAD - packet->taglen[1]; + if (level == QUIC_CRYPTO_INITIAL && len < hlen) { + len = hlen; + plen = len - packet->len; + } if (packet->frames) { - hlen = QUIC_MIN_UDP_PAYLOAD - packet->taglen[1]; - if (level == QUIC_CRYPTO_INITIAL && !quic_is_serv(sk) && len < hlen) { - len = hlen; - plen = len - packet->len; - } sent = quic_packet_sent_alloc(packet->frames); if (!sent) { quic_outq_retransmit_list(sk, &packet->frame_list);