--- sys/kern/uipc_ktls.c.orig +++ sys/kern/uipc_ktls.c @@ -2419,8 +2419,10 @@ * Check if a mbuf chain is fully decrypted at the given offset and * length. Returns KTLS_MBUF_CRYPTO_ST_DECRYPTED if all data is * decrypted. KTLS_MBUF_CRYPTO_ST_MIXED if there is a mix of encrypted - * and decrypted data. Else KTLS_MBUF_CRYPTO_ST_ENCRYPTED if all data - * is encrypted. + * and decrypted data. KTLS_MBUF_CRYPTO_ST_ENCRYPTED if all data is + * encrypted. KTLS_MBUF_CRYPTO_ST_SHAREDMBUF if any mbuf points at + * shared data that must not be modified in place (non-anonymous + * M_EXTPG or sendfile M_EXT buffers). */ ktls_mbuf_crypto_st_t ktls_mbuf_crypto_state(struct mbuf *mb, int offset, int len) @@ -2436,6 +2438,13 @@ offset += len; for (; mb != NULL; mb = mb->m_next) { + if ((mb->m_flags & M_EXTPG) != 0 && + (mb->m_epg_flags & EPG_FLAG_ANON) == 0) + return (KTLS_MBUF_CRYPTO_ST_SHAREDMBUF); + if ((mb->m_flags & M_EXT) != 0 && + mb->m_ext.ext_type == EXT_SFBUF) + return (KTLS_MBUF_CRYPTO_ST_SHAREDMBUF); + m_flags_ored |= mb->m_flags; m_flags_anded &= mb->m_flags; @@ -2636,9 +2645,11 @@ record_type = hdr->tls_type; } break; - default: + case KTLS_MBUF_CRYPTO_ST_SHAREDMBUF: error = EINVAL; break; + default: + __assert_unreachable(); } if (error) { counter_u64_add(ktls_offload_failed_crypto, 1); --- sys/sys/ktls.h.orig +++ sys/sys/ktls.h @@ -241,6 +241,7 @@ KTLS_MBUF_CRYPTO_ST_MIXED = 0, KTLS_MBUF_CRYPTO_ST_ENCRYPTED = 1, KTLS_MBUF_CRYPTO_ST_DECRYPTED = -1, + KTLS_MBUF_CRYPTO_ST_SHAREDMBUF = -2, } ktls_mbuf_crypto_st_t; void ktls_check_rx(struct sockbuf *sb); --- tests/sys/kern/ktls_test.c.orig +++ tests/sys/kern/ktls_test.c @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -2817,6 +2818,97 @@ ATF_REQUIRE(close(s) == 0); } +/* + * Verify that the KTLS receive path does not overwrite data belonging + * to a file whose payload is transmitted over a loopback connection + * via plain sendfile. + */ +ATF_TC_WITHOUT_HEAD(ktls_receive_loopback_sendfile); +ATF_TC_BODY(ktls_receive_loopback_sendfile, tc) +{ + struct tls_enable en; + struct msghdr msg; + struct sf_hdtr hdtr; + struct iovec iov[2]; + uint64_t seqno; + off_t sbytes; + char cbuf[CMSG_SPACE(sizeof(struct tls_get_record))]; + char *plaintext, *ciphertext, *outbuf; + void *p; + const size_t payload_len = PAGE_SIZE; + ssize_t rv; + size_t len; + int mode, shm, sockets[2]; + socklen_t slen; + + ATF_REQUIRE_KTLS(); + seqno = random(); + build_tls_enable(tc, CRYPTO_AES_NIST_GCM_16, 128 / 8, 0, + TLS_MINOR_VER_TWO, seqno, &en); + + len = tls_header_len(&en) + payload_len + tls_trailer_len(&en); + plaintext = alloc_buffer(payload_len); + ciphertext = malloc(len); + ATF_REQUIRE_INTEQ(len, encrypt_tls_record(tc, &en, TLS_RLTYPE_APP, + seqno, plaintext, payload_len, ciphertext, len, 0)); + + ATF_REQUIRE((shm = shm_open(SHM_ANON, O_RDWR, 0600)) > 0); + ATF_REQUIRE_INTEQ(0, ftruncate(shm, payload_len)); + ATF_REQUIRE((p = mmap(NULL, payload_len, PROT_READ | PROT_WRITE, + MAP_SHARED, shm, 0)) != MAP_FAILED); + memcpy(p, ciphertext + tls_header_len(&en), payload_len); + + ATF_REQUIRE_MSG(socketpair_tcp(sockets), "failed to create sockets"); + ATF_REQUIRE(setsockopt(sockets[0], IPPROTO_TCP, TCP_RXTLS_ENABLE, &en, + sizeof(en)) == 0); + slen = sizeof(mode); + ATF_REQUIRE_INTEQ(0, getsockopt(sockets[0], IPPROTO_TCP, TCP_RXTLS_MODE, + &mode, &slen)); + ATF_REQUIRE_INTEQ(TCP_TLS_MODE_SW, mode); + + fd_set_blocking(sockets[0]); + fd_set_blocking(sockets[1]); + + iov[0].iov_base = ciphertext; + iov[0].iov_len = tls_header_len(&en); + iov[1].iov_base = ciphertext + tls_header_len(&en) + payload_len; + iov[1].iov_len = tls_trailer_len(&en); + hdtr.headers = iov; + hdtr.hdr_cnt = 1; + hdtr.trailers = iov + 1; + hdtr.trl_cnt = 1; + debug_hexdump(tc, p, payload_len, "shm buffer before"); + ATF_REQUIRE_INTEQ(0, sendfile(shm, sockets[1], 0, payload_len, &hdtr, + &sbytes, 0)); + ATF_REQUIRE_INTEQ(sbytes, len); + + outbuf = calloc(payload_len, 1); + + memset(&msg, 0, sizeof(msg)); + + msg.msg_control = cbuf; + msg.msg_controllen = sizeof(cbuf); + + iov[0].iov_base = outbuf; + iov[0].iov_len = payload_len; + msg.msg_iov = iov; + msg.msg_iovlen = 1; + + rv = recvmsg(sockets[0], &msg, 0); + if (rv >= 0) { + ATF_REQUIRE_INTEQ(payload_len, rv); + ATF_REQUIRE_INTEQ(0, memcmp(outbuf, plaintext, payload_len)); + } else + ATF_REQUIRE_ERRNO(EBADMSG, true); + + debug_hexdump(tc, p, payload_len, "shm buffer after"); + ATF_REQUIRE_INTEQ(0, memcmp(p, ciphertext + tls_header_len(&en), + payload_len)); + + close_sockets_ignore_errors(sockets); + (void)close(shm); +} + ATF_TP_ADD_TCS(tp) { /* Transmit tests */ @@ -2843,6 +2935,7 @@ /* Miscellaneous */ ATF_TP_ADD_TC(tp, ktls_sendto_baddst); ATF_TP_ADD_TC(tp, ktls_listening_socket); + ATF_TP_ADD_TC(tp, ktls_receive_loopback_sendfile); return (atf_no_error()); }