Piggyback DTLS handshake in initial STUN packets

This change puts the DTLS handshake as payload of STUN packets with a custom STUN attribute (registered with the IANA) and starts the DTLS handshake before the ICE transport becomes writable. Effectively, STUN acts as a transport layer for DTLS during the handshake phase.

This will theoretically reduce the call setup time by one RTT for aggressive nomination or two RTTs for regular nomination.

The latest DTLS packet (flight) is cached and sent on every STUN request or response. DTLS packets are extracted from every authenticated STUN request or response and handled to the DTLS layer for processing.
The caching also increases the resilience to packet loss as STUN pacing is more aggressive (every 20ms) than the exponential backoff used by DTLS which should reduce call setup time in lossy networks.

If the other side of the connection does not support this feature the fallback to normal DTLS happens as soon as the ICE transport becomes writable. This also handles edge-cases like fragmentation of the DTLS handshake.

The feature is only supported when ECDSA certificates are used since RSA certificates are too large to transport as STUN attributes. The observed attributes for the server and client flights with the certificates were around 600 to 650 bytes. This may be further reduced by using raw public keys defined in RFC 7250.

This feature is disabled by default and guarded by the field trial
  WebRTC-IceHandshakeDtls
and requires experimentation and standardization before roll-out in the browser.

Parts of this landed in
  https://webrtc-review.googlesource.com/c/src/+/370679

BUG=webrtc:367395350

Change-Id: I4809438b2a267c4690a9b2bd6f1766d2f959500d
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/362480
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#43742}
This commit is contained in:
Philipp Hancke 2025-01-14 12:16:55 -08:00 committed by WebRTC LUCI CQ
parent cfaba8fd2d
commit bce57cda1e
14 changed files with 618 additions and 24 deletions

View File

@ -191,6 +191,9 @@ ACTIVE_FIELD_TRIALS: FrozenSet[FieldTrial] = frozenset([
FieldTrial('WebRTC-Video-Vp9FlexibleMode', FieldTrial('WebRTC-Video-Vp9FlexibleMode',
329396373, 329396373,
date(2025, 6, 26)), date(2025, 6, 26)),
FieldTrial('WebRTC-IceHandshakeDtls',
367395350,
date(2026, 1, 1)),
# keep-sorted end # keep-sorted end
]) # yapf: disable ]) # yapf: disable

View File

@ -118,6 +118,7 @@ rtc_library("rtc_p2p") {
"../rtc_base:byte_order", "../rtc_base:byte_order",
"../rtc_base:callback_list", "../rtc_base:callback_list",
"../rtc_base:checks", "../rtc_base:checks",
"../rtc_base:copy_on_write_buffer",
"../rtc_base:crc32", "../rtc_base:crc32",
"../rtc_base:dscp", "../rtc_base:dscp",
"../rtc_base:event_tracer", "../rtc_base:event_tracer",
@ -331,6 +332,7 @@ rtc_library("connection") {
"../rtc_base:byte_buffer", "../rtc_base:byte_buffer",
"../rtc_base:callback_list", "../rtc_base:callback_list",
"../rtc_base:checks", "../rtc_base:checks",
"../rtc_base:copy_on_write_buffer",
"../rtc_base:crc32", "../rtc_base:crc32",
"../rtc_base:crypto_random", "../rtc_base:crypto_random",
"../rtc_base:dscp", "../rtc_base:dscp",
@ -340,6 +342,7 @@ rtc_library("connection") {
"../rtc_base:macromagic", "../rtc_base:macromagic",
"../rtc_base:mdns_responder_interface", "../rtc_base:mdns_responder_interface",
"../rtc_base:net_helper", "../rtc_base:net_helper",
"../rtc_base:net_helpers",
"../rtc_base:network", "../rtc_base:network",
"../rtc_base:network_constants", "../rtc_base:network_constants",
"../rtc_base:rate_tracker", "../rtc_base:rate_tracker",
@ -358,6 +361,7 @@ rtc_library("connection") {
"../rtc_base/third_party/base64", "../rtc_base/third_party/base64",
"../rtc_base/third_party/sigslot", "../rtc_base/third_party/sigslot",
"//third_party/abseil-cpp/absl/algorithm:container", "//third_party/abseil-cpp/absl/algorithm:container",
"//third_party/abseil-cpp/absl/base:core_headers",
"//third_party/abseil-cpp/absl/functional:any_invocable", "//third_party/abseil-cpp/absl/functional:any_invocable",
"//third_party/abseil-cpp/absl/strings:string_view", "//third_party/abseil-cpp/absl/strings:string_view",
] ]
@ -530,12 +534,14 @@ rtc_library("ice_transport_internal") {
":transport_description", ":transport_description",
"../api:array_view", "../api:array_view",
"../api:candidate", "../api:candidate",
"../api:field_trials_view",
"../api:rtc_error", "../api:rtc_error",
"../api/transport:enums", "../api/transport:enums",
"../rtc_base:callback_list", "../rtc_base:callback_list",
"../rtc_base:checks", "../rtc_base:checks",
"../rtc_base:network_constants", "../rtc_base:network_constants",
"../rtc_base:timeutils", "../rtc_base:timeutils",
"../rtc_base/network:received_packet",
"../rtc_base/system:rtc_export", "../rtc_base/system:rtc_export",
"../rtc_base/third_party/sigslot", "../rtc_base/third_party/sigslot",
"//third_party/abseil-cpp/absl/functional:any_invocable", "//third_party/abseil-cpp/absl/functional:any_invocable",
@ -563,6 +569,8 @@ rtc_library("p2p_transport_channel") {
":candidate_pair_interface", ":candidate_pair_interface",
":connection", ":connection",
":connection_info", ":connection_info",
":dtls_stun_piggyback_controller",
":dtls_utils",
":ice_agent_interface", ":ice_agent_interface",
":ice_controller_factory_interface", ":ice_controller_factory_interface",
":ice_controller_interface", ":ice_controller_interface",
@ -591,6 +599,7 @@ rtc_library("p2p_transport_channel") {
"../logging:ice_log", "../logging:ice_log",
"../rtc_base:async_packet_socket", "../rtc_base:async_packet_socket",
"../rtc_base:checks", "../rtc_base:checks",
"../rtc_base:copy_on_write_buffer",
"../rtc_base:dscp", "../rtc_base:dscp",
"../rtc_base:event_tracer", "../rtc_base:event_tracer",
"../rtc_base:ip_address", "../rtc_base:ip_address",
@ -808,10 +817,12 @@ rtc_library("dtls_stun_piggyback_controller") {
"../api:array_view", "../api:array_view",
"../api:sequence_checker", "../api:sequence_checker",
"../api/transport:stun_types", "../api/transport:stun_types",
"../api/transport:stun_types",
"../rtc_base:buffer", "../rtc_base:buffer",
"../rtc_base:byte_buffer", "../rtc_base:byte_buffer",
"../rtc_base:checks", "../rtc_base:checks",
"../rtc_base:logging", "../rtc_base:logging",
"../rtc_base:logging",
"../rtc_base:macromagic", "../rtc_base:macromagic",
"../rtc_base:stringutils", "../rtc_base:stringutils",
"../rtc_base/system:no_unique_address", "../rtc_base/system:no_unique_address",
@ -1165,6 +1176,7 @@ if (rtc_include_tests) {
"base/turn_server_unittest.cc", "base/turn_server_unittest.cc",
"base/wrapping_active_ice_controller_unittest.cc", "base/wrapping_active_ice_controller_unittest.cc",
"client/basic_port_allocator_unittest.cc", "client/basic_port_allocator_unittest.cc",
"dtls/dtls_ice_integrationtest.cc",
"dtls/dtls_stun_piggyback_controller_unittest.cc", "dtls/dtls_stun_piggyback_controller_unittest.cc",
"dtls/dtls_transport_unittest.cc", "dtls/dtls_transport_unittest.cc",
"dtls/dtls_utils_unittest.cc", "dtls/dtls_utils_unittest.cc",
@ -1176,14 +1188,19 @@ if (rtc_include_tests) {
":basic_ice_controller", ":basic_ice_controller",
":basic_packet_socket_factory", ":basic_packet_socket_factory",
":basic_port_allocator", ":basic_port_allocator",
":candidate_pair_interface",
":connection", ":connection",
":connection_info",
":dtls_stun_piggyback_controller", ":dtls_stun_piggyback_controller",
":dtls_transport", ":dtls_transport",
":dtls_transport_internal", ":dtls_transport_internal",
":dtls_utils", ":dtls_utils",
":fake_ice_transport", ":fake_ice_transport",
":fake_port_allocator", ":fake_port_allocator",
":ice_controller_factory_interface",
":ice_controller_interface",
":ice_credentials_iterator", ":ice_credentials_iterator",
":ice_switch_reason",
":ice_transport_internal", ":ice_transport_internal",
":p2p_constants", ":p2p_constants",
":p2p_server_utils", ":p2p_server_utils",
@ -1205,9 +1222,11 @@ if (rtc_include_tests) {
":turn_port", ":turn_port",
":wrapping_active_ice_controller", ":wrapping_active_ice_controller",
"../api:array_view", "../api:array_view",
"../api:async_dns_resolver",
"../api:candidate", "../api:candidate",
"../api:dtls_transport_interface", "../api:dtls_transport_interface",
"../api:field_trials_view", "../api:field_trials_view",
"../api:ice_transport_interface",
"../api:libjingle_peerconnection_api", "../api:libjingle_peerconnection_api",
"../api:mock_async_dns_resolver", "../api:mock_async_dns_resolver",
"../api:packet_socket_factory", "../api:packet_socket_factory",
@ -1215,6 +1234,7 @@ if (rtc_include_tests) {
"../api/crypto:options", "../api/crypto:options",
"../api/task_queue", "../api/task_queue",
"../api/task_queue:pending_task_safety_flag", "../api/task_queue:pending_task_safety_flag",
"../api/transport:enums",
"../api/transport:stun_types", "../api/transport:stun_types",
"../api/units:time_delta", "../api/units:time_delta",
"../rtc_base:async_packet_socket", "../rtc_base:async_packet_socket",
@ -1236,11 +1256,13 @@ if (rtc_include_tests) {
"../rtc_base:net_test_helpers", "../rtc_base:net_test_helpers",
"../rtc_base:network", "../rtc_base:network",
"../rtc_base:network_constants", "../rtc_base:network_constants",
"../rtc_base:network_route",
"../rtc_base:rtc_base_tests_utils", "../rtc_base:rtc_base_tests_utils",
"../rtc_base:socket", "../rtc_base:socket",
"../rtc_base:socket_adapters", "../rtc_base:socket_adapters",
"../rtc_base:socket_address", "../rtc_base:socket_address",
"../rtc_base:socket_address_pair", "../rtc_base:socket_address_pair",
"../rtc_base:socket_server",
"../rtc_base:ssl", "../rtc_base:ssl",
"../rtc_base:ssl_adapter", "../rtc_base:ssl_adapter",
"../rtc_base:stringutils", "../rtc_base:stringutils",

View File

@ -13,23 +13,43 @@
#include <math.h> #include <math.h>
#include <algorithm> #include <algorithm>
#include <cstddef>
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/algorithm/container.h" #include "absl/algorithm/container.h"
#include "absl/base/attributes.h"
#include "absl/functional/any_invocable.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "api/array_view.h" #include "api/array_view.h"
#include "api/candidate.h"
#include "api/rtc_error.h"
#include "api/sequence_checker.h"
#include "api/task_queue/task_queue_base.h"
#include "api/transport/stun.h"
#include "api/units/timestamp.h" #include "api/units/timestamp.h"
#include "logging/rtc_event_log/events/rtc_event_ice_candidate_pair.h"
#include "logging/rtc_event_log/events/rtc_event_ice_candidate_pair_config.h"
#include "logging/rtc_event_log/ice_logger.h"
#include "p2p/base/connection_info.h"
#include "p2p/base/p2p_constants.h" #include "p2p/base/p2p_constants.h"
#include "p2p/base/p2p_transport_channel_ice_field_trials.h"
#include "p2p/base/port_interface.h"
#include "p2p/base/stun_request.h"
#include "p2p/base/transport_description.h"
#include "rtc_base/async_packet_socket.h"
#include "rtc_base/byte_buffer.h" #include "rtc_base/byte_buffer.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/crypto_random.h" #include "rtc_base/crypto_random.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
#include "rtc_base/net_helper.h" #include "rtc_base/net_helper.h"
#include "rtc_base/net_helpers.h"
#include "rtc_base/network.h" #include "rtc_base/network.h"
#include "rtc_base/network/received_packet.h"
#include "rtc_base/network/sent_packet.h" #include "rtc_base/network/sent_packet.h"
#include "rtc_base/network_constants.h" #include "rtc_base/network_constants.h"
#include "rtc_base/numerics/safe_minmax.h" #include "rtc_base/numerics/safe_minmax.h"
@ -39,6 +59,7 @@
#include "rtc_base/string_utils.h" #include "rtc_base/string_utils.h"
#include "rtc_base/strings/string_builder.h" #include "rtc_base/strings/string_builder.h"
#include "rtc_base/time_utils.h" #include "rtc_base/time_utils.h"
#include "rtc_base/weak_ptr.h"
namespace cricket { namespace cricket {
namespace { namespace {
@ -557,6 +578,15 @@ void Connection::OnReadPacket(const rtc::ReceivedPacket& packet) {
// This doesn't just check, it makes callbacks if transaction // This doesn't just check, it makes callbacks if transaction
// id's match. // id's match.
case STUN_BINDING_RESPONSE: case STUN_BINDING_RESPONSE:
if (dtls_stun_piggyback_consumer_) {
const StunByteStringAttribute* dtls_piggyback_attribute =
msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN);
const StunByteStringAttribute* dtls_piggyback_ack =
msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN_ACK);
dtls_stun_piggyback_consumer_(dtls_piggyback_attribute,
dtls_piggyback_ack);
}
ABSL_FALLTHROUGH_INTENDED;
case STUN_BINDING_ERROR_RESPONSE: case STUN_BINDING_ERROR_RESPONSE:
requests_.CheckResponse(msg.get()); requests_.CheckResponse(msg.get());
break; break;
@ -581,6 +611,36 @@ void Connection::OnReadPacket(const rtc::ReceivedPacket& packet) {
} }
} }
void Connection::MaybeAddDtlsPiggybackingAttributes(StunMessage* msg) {
if (!(dtls_stun_piggyback_data_producer_ &&
dtls_stun_piggyback_ack_producer_)) {
return;
}
std::optional<absl::string_view> dtls_piggyback_attr =
dtls_stun_piggyback_data_producer_(STUN_BINDING_RESPONSE);
std::optional<absl::string_view> dtls_piggyback_ack =
dtls_stun_piggyback_ack_producer_(STUN_BINDING_REQUEST);
size_t need_length =
(dtls_piggyback_attr
? dtls_piggyback_attr->length() + kStunAttributeHeaderSize
: 0) +
(dtls_piggyback_ack
? dtls_piggyback_ack->length() + kStunAttributeHeaderSize
: 0);
if (msg->length() + need_length > kMaxStunBindingLength) {
return;
}
if (dtls_piggyback_attr) {
msg->AddAttribute(std::make_unique<StunByteStringAttribute>(
STUN_ATTR_META_DTLS_IN_STUN, *dtls_piggyback_attr));
}
if (dtls_piggyback_ack) {
msg->AddAttribute(std::make_unique<StunByteStringAttribute>(
STUN_ATTR_META_DTLS_IN_STUN_ACK, *dtls_piggyback_ack));
}
}
void Connection::HandleStunBindingOrGoogPingRequest(IceMessage* msg) { void Connection::HandleStunBindingOrGoogPingRequest(IceMessage* msg) {
RTC_DCHECK_RUN_ON(network_thread_); RTC_DCHECK_RUN_ON(network_thread_);
// This connection should now be receiving. // This connection should now be receiving.
@ -623,6 +683,14 @@ void Connection::HandleStunBindingOrGoogPingRequest(IceMessage* msg) {
// This is a validated stun request from remote peer. // This is a validated stun request from remote peer.
if (msg->type() == STUN_BINDING_REQUEST) { if (msg->type() == STUN_BINDING_REQUEST) {
if (dtls_stun_piggyback_consumer_) {
const StunByteStringAttribute* dtls_piggyback_attribute =
msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN);
const StunByteStringAttribute* dtls_piggyback_ack =
msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN_ACK);
dtls_stun_piggyback_consumer_(dtls_piggyback_attribute,
dtls_piggyback_ack);
}
SendStunBindingResponse(msg); SendStunBindingResponse(msg);
} else { } else {
RTC_DCHECK(msg->type() == GOOG_PING_REQUEST); RTC_DCHECK(msg->type() == GOOG_PING_REQUEST);
@ -747,6 +815,8 @@ void Connection::SendStunBindingResponse(const StunMessage* message) {
} }
} }
MaybeAddDtlsPiggybackingAttributes(&response);
response.AddMessageIntegrity(local_candidate().password()); response.AddMessageIntegrity(local_candidate().password());
response.AddFingerprint(); response.AddFingerprint();
@ -1083,6 +1153,8 @@ std::unique_ptr<IceMessage> Connection::BuildPingRequest(
message->AddAttribute(std::move(delta)); message->AddAttribute(std::move(delta));
} }
MaybeAddDtlsPiggybackingAttributes(message.get());
message->AddMessageIntegrity(remote_candidate_.password()); message->AddMessageIntegrity(remote_candidate_.password());
message->AddFingerprint(); message->AddFingerprint();
@ -1483,6 +1555,21 @@ void Connection::OnConnectionRequestResponse(StunRequest* request,
} else if (delta_ack) { } else if (delta_ack) {
RTC_LOG(LS_ERROR) << "Discard GOOG_DELTA_ACK, no consumer"; RTC_LOG(LS_ERROR) << "Discard GOOG_DELTA_ACK, no consumer";
} }
if (dtls_stun_piggyback_consumer_) {
const bool sent_dtls_piggyback =
request->msg()->GetByteString(STUN_ATTR_META_DTLS_IN_STUN) != nullptr;
const bool sent_dtls_piggyback_ack =
request->msg()->GetByteString(STUN_ATTR_META_DTLS_IN_STUN_ACK) !=
nullptr;
const StunByteStringAttribute* dtls_piggyback_attr =
response->GetByteString(STUN_ATTR_META_DTLS_IN_STUN);
const StunByteStringAttribute* dtls_piggyback_ack =
response->GetByteString(STUN_ATTR_META_DTLS_IN_STUN_ACK);
if (sent_dtls_piggyback || sent_dtls_piggyback_ack) {
dtls_stun_piggyback_consumer_(dtls_piggyback_attr, dtls_piggyback_ack);
}
}
} }
void Connection::OnConnectionRequestErrorResponse(ConnectionRequest* request, void Connection::OnConnectionRequestErrorResponse(ConnectionRequest* request,

View File

@ -18,7 +18,6 @@
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <string> #include <string>
#include <type_traits>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -53,6 +52,8 @@ namespace cricket {
// Version number for GOOG_PING, this is added to have the option of // Version number for GOOG_PING, this is added to have the option of
// adding other flavors in the future. // adding other flavors in the future.
constexpr int kGoogPingVersion = 1; constexpr int kGoogPingVersion = 1;
// 1200 is the "commonly used" MTU. Subtract M-I attribute (20+4) and FP (4+4).
constexpr int kMaxStunBindingLength = 1200 - 24 - 8;
// Forward declaration so that a ConnectionRequest can contain a Connection. // Forward declaration so that a ConnectionRequest can contain a Connection.
class Connection; class Connection;
@ -359,6 +360,23 @@ class RTC_EXPORT Connection : public CandidatePairInterface {
goog_delta_ack_consumer_ = std::nullopt; goog_delta_ack_consumer_ = std::nullopt;
} }
void RegisterDtlsPiggyback(
absl::AnyInvocable<std::optional<absl::string_view>(StunMessageType)>
data_producer,
absl::AnyInvocable<std::optional<absl::string_view>(StunMessageType)>
ack_producer,
absl::AnyInvocable<void(const StunByteStringAttribute*,
const StunByteStringAttribute*)> consumer) {
dtls_stun_piggyback_data_producer_ = std::move(data_producer);
dtls_stun_piggyback_ack_producer_ = std::move(ack_producer);
dtls_stun_piggyback_consumer_ = std::move(consumer);
}
void DeregisterDtlsPiggyback() {
dtls_stun_piggyback_consumer_ = nullptr;
dtls_stun_piggyback_data_producer_ = nullptr;
dtls_stun_piggyback_ack_producer_ = nullptr;
}
protected: protected:
// A ConnectionRequest is a simple STUN ping used to determine writability. // A ConnectionRequest is a simple STUN ping used to determine writability.
class ConnectionRequest; class ConnectionRequest;
@ -511,6 +529,15 @@ class RTC_EXPORT Connection : public CandidatePairInterface {
goog_delta_ack_consumer_; goog_delta_ack_consumer_;
absl::AnyInvocable<void(Connection*, const rtc::ReceivedPacket&)> absl::AnyInvocable<void(Connection*, const rtc::ReceivedPacket&)>
received_packet_callback_; received_packet_callback_;
void MaybeAddDtlsPiggybackingAttributes(StunMessage* msg);
absl::AnyInvocable<std::optional<absl::string_view>(StunMessageType)>
dtls_stun_piggyback_data_producer_ = nullptr;
absl::AnyInvocable<std::optional<absl::string_view>(StunMessageType)>
dtls_stun_piggyback_ack_producer_ = nullptr;
absl::AnyInvocable<void(const StunByteStringAttribute*,
const StunByteStringAttribute*)>
dtls_stun_piggyback_consumer_ = nullptr;
}; };
// ProxyConnection defers all the interesting work to the port. // ProxyConnection defers all the interesting work to the port.

View File

@ -23,6 +23,7 @@
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "api/array_view.h" #include "api/array_view.h"
#include "api/candidate.h" #include "api/candidate.h"
#include "api/field_trials_view.h"
#include "api/rtc_error.h" #include "api/rtc_error.h"
#include "api/transport/enums.h" #include "api/transport/enums.h"
#include "p2p/base/candidate_pair_interface.h" #include "p2p/base/candidate_pair_interface.h"
@ -34,6 +35,7 @@
#include "p2p/base/transport_description.h" #include "p2p/base/transport_description.h"
#include "rtc_base/callback_list.h" #include "rtc_base/callback_list.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/network/received_packet.h"
#include "rtc_base/network_constants.h" #include "rtc_base/network_constants.h"
#include "rtc_base/system/rtc_export.h" #include "rtc_base/system/rtc_export.h"
#include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/third_party/sigslot/sigslot.h"
@ -201,6 +203,9 @@ struct RTC_EXPORT IceConfig {
webrtc::VpnPreference vpn_preference = webrtc::VpnPreference::kDefault; webrtc::VpnPreference vpn_preference = webrtc::VpnPreference::kDefault;
// Experimental feature to transport the DTLS handshake in STUN packets.
bool dtls_handshake_in_stun = false;
IceConfig(); IceConfig();
IceConfig(int receiving_timeout_ms, IceConfig(int receiving_timeout_ms,
int backup_connection_ping_interval, int backup_connection_ping_interval,
@ -398,6 +403,15 @@ class RTC_EXPORT IceTransportInternal : public rtc::PacketTransportInternal {
virtual const webrtc::FieldTrialsView* field_trials() const { virtual const webrtc::FieldTrialsView* field_trials() const {
return nullptr; return nullptr;
} }
void SetPiggybackDtlsDataCallback(
absl::AnyInvocable<void(rtc::PacketTransportInternal* transport,
const rtc::ReceivedPacket& packet)> callback) {
RTC_DCHECK(callback == nullptr || !piggybacked_dtls_callback_);
piggybacked_dtls_callback_ = std::move(callback);
}
virtual void SetDtlsDataToPiggyback(rtc::ArrayView<const uint8_t>) {}
virtual void SetDtlsHandshakeComplete(bool is_dtls_client) {}
virtual bool IsDtlsPiggybackSupportedByPeer() { return false; }
protected: protected:
void SendGatheringStateEvent() { gathering_state_callback_list_.Send(this); } void SendGatheringStateEvent() { gathering_state_callback_list_.Send(this); }
@ -419,6 +433,9 @@ class RTC_EXPORT IceTransportInternal : public rtc::PacketTransportInternal {
absl::AnyInvocable<void(const cricket::CandidatePairChangeEvent&)> absl::AnyInvocable<void(const cricket::CandidatePairChangeEvent&)>
candidate_pair_change_callback_; candidate_pair_change_callback_;
absl::AnyInvocable<void(rtc::PacketTransportInternal*,
const rtc::ReceivedPacket&)>
piggybacked_dtls_callback_;
}; };
} // namespace cricket } // namespace cricket

View File

@ -55,6 +55,7 @@
#include "p2p/base/regathering_controller.h" #include "p2p/base/regathering_controller.h"
#include "p2p/base/transport_description.h" #include "p2p/base/transport_description.h"
#include "p2p/base/wrapping_active_ice_controller.h" #include "p2p/base/wrapping_active_ice_controller.h"
#include "p2p/dtls/dtls_utils.h"
#include "rtc_base/async_packet_socket.h" #include "rtc_base/async_packet_socket.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/dscp.h" #include "rtc_base/dscp.h"
@ -200,7 +201,16 @@ P2PTransportChannel::P2PTransportChannel(
true /* presume_writable_when_fully_relayed */, true /* presume_writable_when_fully_relayed */,
REGATHER_ON_FAILED_NETWORKS_INTERVAL, REGATHER_ON_FAILED_NETWORKS_INTERVAL,
RECEIVING_SWITCHING_DELAY), RECEIVING_SWITCHING_DELAY),
field_trials_(field_trials) { field_trials_(field_trials),
dtls_stun_piggyback_controller_(
[this](rtc::ArrayView<const uint8_t> piggybacked_dtls_packet) {
if (piggybacked_dtls_callback_ == nullptr) {
return;
}
piggybacked_dtls_callback_(
this, rtc::ReceivedPacket(piggybacked_dtls_packet,
rtc::SocketAddress()));
}) {
TRACE_EVENT0("webrtc", "P2PTransportChannel::P2PTransportChannel"); TRACE_EVENT0("webrtc", "P2PTransportChannel::P2PTransportChannel");
RTC_DCHECK(allocator_ != nullptr); RTC_DCHECK(allocator_ != nullptr);
// Validate IceConfig even for mostly built-in constant default values in case // Validate IceConfig even for mostly built-in constant default values in case
@ -310,6 +320,22 @@ void P2PTransportChannel::AddConnection(Connection* connection) {
[this](webrtc::RTCErrorOr<const StunUInt64Attribute*> delta_ack) { [this](webrtc::RTCErrorOr<const StunUInt64Attribute*> delta_ack) {
GoogDeltaAckReceived(std::move(delta_ack)); GoogDeltaAckReceived(std::move(delta_ack));
}); });
if (config_.dtls_handshake_in_stun) {
connection->RegisterDtlsPiggyback(
[this](StunMessageType stun_message_type) {
return dtls_stun_piggyback_controller_.GetDataToPiggyback(
stun_message_type);
},
[this](StunMessageType stun_message_type) {
return dtls_stun_piggyback_controller_.GetAckToPiggyback(
stun_message_type);
},
[this](const StunByteStringAttribute* data,
const StunByteStringAttribute* ack) {
dtls_stun_piggyback_controller_.ReportDataPiggybacked(data, ack);
});
}
LogCandidatePairConfig(connection, LogCandidatePairConfig(connection,
webrtc::IceCandidatePairConfigType::kAdded); webrtc::IceCandidatePairConfigType::kAdded);
@ -695,6 +721,11 @@ void P2PTransportChannel::SetIceConfig(const IceConfig& config) {
allocator_->SetVpnPreference(config_.vpn_preference); allocator_->SetVpnPreference(config_.vpn_preference);
ice_controller_->SetIceConfig(config_); ice_controller_->SetIceConfig(config_);
if (config_.dtls_handshake_in_stun != config.dtls_handshake_in_stun) {
config_.dtls_handshake_in_stun = config.dtls_handshake_in_stun;
RTC_LOG(LS_INFO) << "Set DTLS handshake in STUN to "
<< config.dtls_handshake_in_stun;
}
RTC_DCHECK(ValidateIceConfig(config_).ok()); RTC_DCHECK(ValidateIceConfig(config_).ok());
} }
@ -1609,6 +1640,16 @@ int P2PTransportChannel::SendPacket(const char* data,
error_ = ENOTCONN; error_ = ENOTCONN;
return -1; return -1;
} }
/*
* When trying DTLS-STUN piggyback we need to drop handshake packets
* as we start fresh if this fails.
*/
if (config_.dtls_handshake_in_stun && IsDtlsPiggybackSupportedByPeer() &&
IsDtlsHandshakePacket(
rtc::MakeArrayView(reinterpret_cast<const uint8_t*>(data), len))) {
RTC_LOG(LS_INFO) << "Dropping DTLS handshake while attemping DTLS-in-STUN";
return len;
}
packets_sent_++; packets_sent_++;
last_sent_packet_id_ = options.packet_id; last_sent_packet_id_ = options.packet_id;
@ -2151,6 +2192,7 @@ void P2PTransportChannel::RemoveConnection(Connection* connection) {
connection->DeregisterReceivedPacketCallback(); connection->DeregisterReceivedPacketCallback();
connections_.erase(it); connections_.erase(it);
connection->ClearStunDictConsumer(); connection->ClearStunDictConsumer();
connection->DeregisterDtlsPiggyback();
ice_controller_->OnConnectionDestroyed(connection); ice_controller_->OnConnectionDestroyed(connection);
} }
@ -2272,6 +2314,12 @@ void P2PTransportChannel::SetWritable(bool writable) {
SignalReadyToSend(this); SignalReadyToSend(this);
} }
SignalWritableState(this); SignalWritableState(this);
if (config_.dtls_handshake_in_stun && IsDtlsPiggybackSupportedByPeer()) {
// Need to STUN ping here to get the last bit of the DTLS handshake across
// as quickly as possible.
SendPingRequestInternal(selected_connection_);
}
} }
void P2PTransportChannel::SetReceiving(bool receiving) { void P2PTransportChannel::SetReceiving(bool receiving) {

View File

@ -58,6 +58,7 @@
#include "p2p/base/regathering_controller.h" #include "p2p/base/regathering_controller.h"
#include "p2p/base/stun_dictionary.h" #include "p2p/base/stun_dictionary.h"
#include "p2p/base/transport_description.h" #include "p2p/base/transport_description.h"
#include "p2p/dtls/dtls_stun_piggyback_controller.h"
#include "rtc_base/async_packet_socket.h" #include "rtc_base/async_packet_socket.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/dscp.h" #include "rtc_base/dscp.h"
@ -251,6 +252,18 @@ class RTC_EXPORT P2PTransportChannel : public IceTransportInternal,
const webrtc::FieldTrialsView* field_trials() const override { const webrtc::FieldTrialsView* field_trials() const override {
return field_trials_; return field_trials_;
} }
void SetDtlsDataToPiggyback(rtc::ArrayView<const uint8_t> data) override {
dtls_stun_piggyback_controller_.SetDataToPiggyback(data);
}
void SetDtlsHandshakeComplete(bool is_dtls_client) override {
dtls_stun_piggyback_controller_.SetDtlsHandshakeComplete(is_dtls_client);
}
bool IsDtlsPiggybackSupportedByPeer() override {
RTC_DCHECK_RUN_ON(network_thread_);
return config_.dtls_handshake_in_stun &&
dtls_stun_piggyback_controller_.state() !=
DtlsStunPiggybackController::State::OFF;
}
private: private:
P2PTransportChannel( P2PTransportChannel(
@ -515,6 +528,9 @@ class RTC_EXPORT P2PTransportChannel : public IceTransportInternal,
// A dictionary that tracks attributes from peer. // A dictionary that tracks attributes from peer.
StunDictionaryView stun_dict_view_; StunDictionaryView stun_dict_view_;
// A controller for piggybacking DTLS in STUN.
DtlsStunPiggybackController dtls_stun_piggyback_controller_;
}; };
} // namespace cricket } // namespace cricket

View File

@ -10,29 +10,57 @@
#include "p2p/base/p2p_transport_channel.h" #include "p2p/base/p2p_transport_channel.h"
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <list> #include <list>
#include <map>
#include <memory> #include <memory>
#include <optional>
#include <string> #include <string>
#include <tuple>
#include <utility> #include <utility>
#include <vector>
#include "absl/algorithm/container.h" #include "absl/algorithm/container.h"
#include "absl/functional/any_invocable.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "api/async_dns_resolver.h"
#include "api/candidate.h"
#include "api/field_trials_view.h"
#include "api/ice_transport_interface.h"
#include "api/packet_socket_factory.h"
#include "api/scoped_refptr.h"
#include "api/task_queue/pending_task_safety_flag.h"
#include "api/test/mock_async_dns_resolver.h" #include "api/test/mock_async_dns_resolver.h"
#include "p2p/base/active_ice_controller_factory_interface.h" #include "api/transport/enums.h"
#include "p2p/base/active_ice_controller_interface.h" #include "api/transport/stun.h"
#include "api/units/time_delta.h"
#include "p2p/base/basic_ice_controller.h" #include "p2p/base/basic_ice_controller.h"
#include "p2p/base/basic_packet_socket_factory.h"
#include "p2p/base/candidate_pair_interface.h"
#include "p2p/base/connection.h" #include "p2p/base/connection.h"
#include "p2p/base/connection_info.h"
#include "p2p/base/fake_port_allocator.h" #include "p2p/base/fake_port_allocator.h"
#include "p2p/base/ice_controller_factory_interface.h"
#include "p2p/base/ice_controller_interface.h"
#include "p2p/base/ice_switch_reason.h"
#include "p2p/base/ice_transport_internal.h" #include "p2p/base/ice_transport_internal.h"
#include "p2p/base/mock_active_ice_controller.h" #include "p2p/base/mock_active_ice_controller.h"
#include "p2p/base/mock_ice_controller.h" #include "p2p/base/mock_ice_controller.h"
#include "p2p/base/p2p_constants.h"
#include "p2p/base/packet_transport_internal.h" #include "p2p/base/packet_transport_internal.h"
#include "p2p/base/port.h"
#include "p2p/base/port_allocator.h"
#include "p2p/base/port_interface.h"
#include "p2p/base/stun_dictionary.h"
#include "p2p/base/stun_server.h"
#include "p2p/base/test_stun_server.h" #include "p2p/base/test_stun_server.h"
#include "p2p/base/test_turn_server.h" #include "p2p/base/test_turn_server.h"
#include "p2p/base/transport_description.h"
#include "p2p/client/basic_port_allocator.h" #include "p2p/client/basic_port_allocator.h"
#include "rtc_base/byte_buffer.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/crypto_random.h"
#include "rtc_base/dscp.h" #include "rtc_base/dscp.h"
#include "rtc_base/fake_clock.h" #include "rtc_base/fake_clock.h"
#include "rtc_base/fake_mdns_responder.h" #include "rtc_base/fake_mdns_responder.h"
@ -40,19 +68,28 @@
#include "rtc_base/firewall_socket_server.h" #include "rtc_base/firewall_socket_server.h"
#include "rtc_base/gunit.h" #include "rtc_base/gunit.h"
#include "rtc_base/internal/default_socket_server.h" #include "rtc_base/internal/default_socket_server.h"
#include "rtc_base/ip_address.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
#include "rtc_base/mdns_responder_interface.h" #include "rtc_base/mdns_responder_interface.h"
#include "rtc_base/nat_server.h"
#include "rtc_base/nat_socket_factory.h" #include "rtc_base/nat_socket_factory.h"
#include "rtc_base/nat_types.h"
#include "rtc_base/net_helper.h"
#include "rtc_base/net_helpers.h"
#include "rtc_base/network.h"
#include "rtc_base/network/received_packet.h" #include "rtc_base/network/received_packet.h"
#include "rtc_base/proxy_server.h" #include "rtc_base/network/sent_packet.h"
#include "rtc_base/network_constants.h"
#include "rtc_base/network_route.h"
#include "rtc_base/socket.h"
#include "rtc_base/socket_address.h" #include "rtc_base/socket_address.h"
#include "rtc_base/ssl_adapter.h" #include "rtc_base/socket_server.h"
#include "rtc_base/strings/string_builder.h" #include "rtc_base/third_party/sigslot/sigslot.h"
#include "rtc_base/thread.h" #include "rtc_base/thread.h"
#include "rtc_base/time_utils.h" #include "rtc_base/time_utils.h"
#include "rtc_base/virtual_socket_server.h" #include "rtc_base/virtual_socket_server.h"
#include "system_wrappers/include/metrics.h" #include "system_wrappers/include/metrics.h"
#include "test/gmock.h"
#include "test/gtest.h"
#include "test/scoped_key_value_config.h" #include "test/scoped_key_value_config.h"
namespace { namespace {
@ -6453,4 +6490,62 @@ TEST_F(P2PTransportChannelTest, TestIceNoOldCandidatesAfterIceRestart) {
DestroyChannels(); DestroyChannels();
} }
class P2PTransportChannelTestDtlsInStun : public P2PTransportChannelTestBase {
public:
P2PTransportChannelTestDtlsInStun() : P2PTransportChannelTestBase() {}
protected:
void Run(bool ep1_support, bool ep2_support) {
IceConfig ep1_config;
ep1_config.dtls_handshake_in_stun = ep1_support;
IceConfig ep2_config;
ep2_config.dtls_handshake_in_stun = ep2_support;
CreateChannels(ep1_config, ep2_config);
// DTLS server hello done message as test data.
std::vector<uint8_t> dtls_data = {
0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x0c, 0x0e, 0x00, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
};
if (ep1_support) {
ep1_ch1()->SetDtlsDataToPiggyback(dtls_data);
}
if (ep2_support) {
ep2_ch1()->SetDtlsDataToPiggyback(dtls_data);
}
EXPECT_TRUE_SIMULATED_WAIT(CheckConnected(ep1_ch1(), ep2_ch1()),
kDefaultTimeout, clock_);
}
rtc::ScopedFakeClock clock_;
};
TEST_F(P2PTransportChannelTestDtlsInStun, NotSupportedByEither) {
Run(false, false);
EXPECT_FALSE(ep1_ch1()->IsDtlsPiggybackSupportedByPeer());
EXPECT_FALSE(ep2_ch1()->IsDtlsPiggybackSupportedByPeer());
DestroyChannels();
}
TEST_F(P2PTransportChannelTestDtlsInStun, SupportedByClient) {
Run(true, false);
EXPECT_FALSE(ep1_ch1()->IsDtlsPiggybackSupportedByPeer());
EXPECT_FALSE(ep2_ch1()->IsDtlsPiggybackSupportedByPeer());
DestroyChannels();
}
TEST_F(P2PTransportChannelTestDtlsInStun, SupportedByServer) {
Run(false, true);
EXPECT_FALSE(ep1_ch1()->IsDtlsPiggybackSupportedByPeer());
EXPECT_FALSE(ep2_ch1()->IsDtlsPiggybackSupportedByPeer());
DestroyChannels();
}
TEST_F(P2PTransportChannelTestDtlsInStun, SupportedByBoth) {
Run(true, true);
EXPECT_TRUE(ep1_ch1()->IsDtlsPiggybackSupportedByPeer());
EXPECT_TRUE(ep2_ch1()->IsDtlsPiggybackSupportedByPeer());
DestroyChannels();
}
} // namespace cricket } // namespace cricket

View File

@ -0,0 +1,190 @@
/*
* Copyright 2024 The WebRTC Project Authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include <cstdint>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include "api/candidate.h"
#include "api/crypto/crypto_options.h"
#include "api/scoped_refptr.h"
#include "p2p/base/basic_packet_socket_factory.h"
#include "p2p/base/ice_transport_internal.h"
#include "p2p/base/p2p_transport_channel.h"
#include "p2p/base/port_allocator.h"
#include "p2p/base/transport_description.h"
#include "p2p/client/basic_port_allocator.h"
#include "p2p/dtls/dtls_transport.h"
#include "rtc_base/fake_clock.h"
#include "rtc_base/fake_network.h"
#include "rtc_base/gunit.h"
#include "rtc_base/rtc_certificate.h"
#include "rtc_base/socket_address.h"
#include "rtc_base/ssl_fingerprint.h"
#include "rtc_base/ssl_identity.h"
#include "rtc_base/ssl_stream_adapter.h"
#include "rtc_base/third_party/sigslot/sigslot.h"
#include "rtc_base/thread.h"
#include "rtc_base/virtual_socket_server.h"
#include "test/gtest.h"
namespace {
constexpr int kDefaultTimeout = 10000;
void SetRemoteFingerprintFromCert(
cricket::DtlsTransport& transport,
const rtc::scoped_refptr<rtc::RTCCertificate>& cert) {
std::unique_ptr<rtc::SSLFingerprint> fingerprint =
rtc::SSLFingerprint::CreateFromCertificate(*cert);
transport.SetRemoteParameters(
fingerprint->algorithm,
reinterpret_cast<const uint8_t*>(fingerprint->digest.data()),
fingerprint->digest.size(), std::nullopt);
}
} // namespace
namespace cricket {
class DtlsIceIntegrationTest
: public ::testing::TestWithParam<std::tuple<bool, bool>>,
public sigslot::has_slots<> {
public:
void CandidateC2S(IceTransportInternal*, const Candidate& c) {
thread_.PostTask([this, c = c]() { server_ice_->AddRemoteCandidate(c); });
}
void CandidateS2C(IceTransportInternal*, const Candidate& c) {
thread_.PostTask([this, c = c]() { client_ice_->AddRemoteCandidate(c); });
}
protected:
DtlsIceIntegrationTest()
: ss_(std::make_unique<rtc::VirtualSocketServer>()),
socket_factory_(
std::make_unique<rtc::BasicPacketSocketFactory>(ss_.get())),
thread_(ss_.get()),
client_allocator_(
std::make_unique<BasicPortAllocator>(&network_manager_,
socket_factory_.get())),
server_allocator_(
std::make_unique<BasicPortAllocator>(&network_manager_,
socket_factory_.get())),
client_ice_(
std::make_unique<P2PTransportChannel>("client_transport",
0,
client_allocator_.get())),
server_ice_(
std::make_unique<P2PTransportChannel>("server_transport",
0,
server_allocator_.get())),
client_dtls_(client_ice_.get(),
webrtc::CryptoOptions(),
/*event_log=*/nullptr,
rtc::SSL_PROTOCOL_DTLS_12),
server_dtls_(server_ice_.get(),
webrtc::CryptoOptions(),
/*event_log=*/nullptr,
rtc::SSL_PROTOCOL_DTLS_12),
client_ice_parameters_("c_ufrag",
"c_icepwd_something_something",
false),
server_ice_parameters_("s_ufrag",
"s_icepwd_something_something",
false),
client_dtls_stun_piggyback_(std::get<0>(GetParam())),
server_dtls_stun_piggyback_(std::get<1>(GetParam())) {
// Setup ICE.
client_ice_->SetIceParameters(client_ice_parameters_);
client_ice_->SetRemoteIceParameters(server_ice_parameters_);
client_ice_->SetIceRole(ICEROLE_CONTROLLING);
client_ice_->SignalCandidateGathered.connect(
this, &DtlsIceIntegrationTest::CandidateC2S);
server_ice_->SetIceParameters(server_ice_parameters_);
server_ice_->SetRemoteIceParameters(client_ice_parameters_);
server_ice_->SetIceRole(ICEROLE_CONTROLLED);
server_ice_->SignalCandidateGathered.connect(
this, &DtlsIceIntegrationTest::CandidateS2C);
// Setup DTLS.
auto client_certificate = rtc::RTCCertificate::Create(
rtc::SSLIdentity::Create("test", rtc::KT_DEFAULT));
client_dtls_.SetLocalCertificate(client_certificate);
client_dtls_.SetDtlsRole(rtc::SSL_SERVER);
auto server_certificate = rtc::RTCCertificate::Create(
rtc::SSLIdentity::Create("test", rtc::KT_DEFAULT));
server_dtls_.SetLocalCertificate(server_certificate);
server_dtls_.SetDtlsRole(rtc::SSL_CLIENT);
SetRemoteFingerprintFromCert(server_dtls_, client_certificate);
SetRemoteFingerprintFromCert(client_dtls_, server_certificate);
// Setup the network.
network_manager_.AddInterface(rtc::SocketAddress("192.168.1.1", 0));
client_allocator_->Initialize();
server_allocator_->Initialize();
}
~DtlsIceIntegrationTest() = default;
rtc::FakeNetworkManager network_manager_;
std::unique_ptr<rtc::VirtualSocketServer> ss_;
std::unique_ptr<rtc::BasicPacketSocketFactory> socket_factory_;
rtc::AutoSocketServerThread thread_;
std::unique_ptr<PortAllocator> client_allocator_;
std::unique_ptr<PortAllocator> server_allocator_;
std::unique_ptr<IceTransportInternal> client_ice_;
std::unique_ptr<IceTransportInternal> server_ice_;
DtlsTransport client_dtls_;
DtlsTransport server_dtls_;
IceParameters client_ice_parameters_;
IceParameters server_ice_parameters_;
bool client_dtls_stun_piggyback_;
bool server_dtls_stun_piggyback_;
rtc::ScopedFakeClock fake_clock_;
};
TEST_P(DtlsIceIntegrationTest, SmokeTest) {
cricket::IceConfig client_config;
client_config.dtls_handshake_in_stun = client_dtls_stun_piggyback_;
client_ice_->SetIceConfig(client_config);
cricket::IceConfig server_config;
server_config.dtls_handshake_in_stun = server_dtls_stun_piggyback_;
server_ice_->SetIceConfig(server_config);
client_ice_->MaybeStartGathering();
server_ice_->MaybeStartGathering();
// Note: this only reaches the pending piggybacking state.
EXPECT_TRUE_SIMULATED_WAIT(client_dtls_.writable() && server_dtls_.writable(),
kDefaultTimeout, fake_clock_);
EXPECT_EQ(client_ice_->IsDtlsPiggybackSupportedByPeer(),
client_dtls_stun_piggyback_ && server_dtls_stun_piggyback_);
EXPECT_EQ(server_ice_->IsDtlsPiggybackSupportedByPeer(),
client_dtls_stun_piggyback_ && server_dtls_stun_piggyback_);
}
INSTANTIATE_TEST_SUITE_P(DtlsStunPiggybackingIntegrationTest,
DtlsIceIntegrationTest,
::testing::Values(std::make_pair(false, false),
std::make_pair(true, false),
std::make_pair(false, true),
std::make_pair(true, true)));
} // namespace cricket

View File

@ -60,8 +60,10 @@ static const size_t kMaxPendingPackets = 2;
// Minimum and maximum values for the initial DTLS handshake timeout. We'll pick // Minimum and maximum values for the initial DTLS handshake timeout. We'll pick
// an initial timeout based on ICE RTT estimates, but clamp it to this range. // an initial timeout based on ICE RTT estimates, but clamp it to this range.
static const int kMinHandshakeTimeout = 50; static const int kMinHandshakeTimeoutMs = 50;
static const int kMaxHandshakeTimeout = 3000; static const int kMaxHandshakeTimeoutMs = 3000;
// This effectively disables the handshake timeout.
static const int kDisabledHandshakeTimeoutMs = 3600 * 1000 * 24;
static bool IsRtpPacket(rtc::ArrayView<const uint8_t> payload) { static bool IsRtpPacket(rtc::ArrayView<const uint8_t> payload) {
const uint8_t* u = payload.data(); const uint8_t* u = payload.data();
@ -96,6 +98,13 @@ rtc::StreamResult StreamInterfaceChannel::Write(
size_t& written, size_t& written,
int& /* error */) { int& /* error */) {
RTC_DCHECK_RUN_ON(&callback_sequence_); RTC_DCHECK_RUN_ON(&callback_sequence_);
if (IsDtlsHandshakePacket(data) &&
ice_transport_->IsDtlsPiggybackSupportedByPeer()) {
ice_transport_->SetDtlsDataToPiggyback(data);
// The ICE transport is responsible for dropping these packets.
}
// Always succeeds, since this is an unreliable transport anyway. // Always succeeds, since this is an unreliable transport anyway.
// TODO(zhihuang): Should this block if ice_transport_'s temporarily // TODO(zhihuang): Should this block if ice_transport_'s temporarily
// unwritable? // unwritable?
@ -150,6 +159,7 @@ DtlsTransport::DtlsTransport(IceTransportInternal* ice_transport,
DtlsTransport::~DtlsTransport() { DtlsTransport::~DtlsTransport() {
if (ice_transport_) { if (ice_transport_) {
ice_transport_->SetPiggybackDtlsDataCallback(nullptr);
ice_transport_->DeregisterReceivedPacketCallback(this); ice_transport_->DeregisterReceivedPacketCallback(this);
} }
} }
@ -531,6 +541,20 @@ void DtlsTransport::ConnectToIceTransport() {
this, &DtlsTransport::OnReceivingState); this, &DtlsTransport::OnReceivingState);
ice_transport_->SignalNetworkRouteChanged.connect( ice_transport_->SignalNetworkRouteChanged.connect(
this, &DtlsTransport::OnNetworkRouteChanged); this, &DtlsTransport::OnNetworkRouteChanged);
ice_transport_->SetPiggybackDtlsDataCallback(
[this](rtc::PacketTransportInternal* transport,
const rtc::ReceivedPacket& packet) {
RTC_DCHECK(dtls_active_);
RTC_DCHECK(IsDtlsHandshakePacket(packet.payload()));
if (!dtls_active_) {
// Not doing DTLS.
return;
}
if (!IsDtlsHandshakePacket(packet.payload())) {
return;
}
OnReadPacket(transport, packet);
});
} }
// The state transition logic here is as follows: // The state transition logic here is as follows:
@ -557,11 +581,37 @@ void DtlsTransport::OnWritableState(rtc::PacketTransportInternal* transport) {
return; return;
} }
// The opportunistic attempt to do DTLS piggybacking failed.
// Recreate the DTLS session. Note: this assumes we can consider
// the previous DTLS session state beyond repair and no packet
// reached the peer.
if (dtls_ && !was_ever_connected_ &&
!ice_transport_->IsDtlsPiggybackSupportedByPeer() &&
(dtls_state() == webrtc::DtlsTransportState::kConnecting ||
dtls_state() == webrtc::DtlsTransportState::kNew)) {
RTC_LOG(LS_ERROR) << "DTLS piggybacking not supported, restarting...";
ice_transport_->SetPiggybackDtlsDataCallback(nullptr);
dtls_.reset(nullptr);
set_dtls_state(webrtc::DtlsTransportState::kNew);
set_writable(false);
if (!SetupDtls()) {
RTC_LOG(LS_ERROR)
<< "Failed to setup DTLS again after attempted piggybacking.";
set_dtls_state(webrtc::DtlsTransportState::kFailed);
return;
}
// SetupDtls has called MaybeStartDtls() already.
return;
}
switch (dtls_state()) { switch (dtls_state()) {
case webrtc::DtlsTransportState::kNew: case webrtc::DtlsTransportState::kNew:
MaybeStartDtls(); MaybeStartDtls();
break; break;
case webrtc::DtlsTransportState::kConnected: case webrtc::DtlsTransportState::kConnected:
was_ever_connected_ = true;
// Note: SignalWritableState fired by set_writable. // Note: SignalWritableState fired by set_writable.
set_writable(ice_transport_->writable()); set_writable(ice_transport_->writable());
break; break;
@ -705,6 +755,7 @@ void DtlsTransport::OnDtlsEvent(int sig, int err) {
// sure we don't accidentally frob the state if it's closed. // sure we don't accidentally frob the state if it's closed.
set_dtls_state(webrtc::DtlsTransportState::kConnected); set_dtls_state(webrtc::DtlsTransportState::kConnected);
set_writable(true); set_writable(true);
ice_transport_->SetDtlsHandshakeComplete(dtls_role_ == rtc::SSL_CLIENT);
} }
} }
if (sig & rtc::SE_READ) { if (sig & rtc::SE_READ) {
@ -762,8 +813,13 @@ void DtlsTransport::OnNetworkRouteChanged(
} }
void DtlsTransport::MaybeStartDtls() { void DtlsTransport::MaybeStartDtls() {
if (dtls_ && ice_transport_->writable()) { RTC_DCHECK(ice_transport_);
ConfigureHandshakeTimeout(); // When adding the DTLS handshake in STUN we want to call StartSSL even
// before the ICE transport is ready.
bool start_early_for_dtls_in_stun =
ice_transport_->config().dtls_handshake_in_stun;
if (dtls_ && (ice_transport_->writable() || start_early_for_dtls_in_stun)) {
ConfigureHandshakeTimeout(start_early_for_dtls_in_stun);
if (dtls_->StartSSL()) { if (dtls_->StartSSL()) {
// This should never fail: // This should never fail:
@ -851,18 +907,26 @@ void DtlsTransport::OnDtlsHandshakeError(rtc::SSLHandshakeError error) {
SendDtlsHandshakeError(error); SendDtlsHandshakeError(error);
} }
void DtlsTransport::ConfigureHandshakeTimeout() { void DtlsTransport::ConfigureHandshakeTimeout(bool uses_dtls_in_stun) {
RTC_DCHECK(dtls_); RTC_DCHECK(dtls_);
std::optional<int> rtt = ice_transport_->GetRttEstimate(); std::optional<int> rtt_ms = ice_transport_->GetRttEstimate();
if (rtt) { if (uses_dtls_in_stun) {
// Configure a very high timeout to effectively disable the DTLS timeout
// and avoid fragmented resends. This is ok since DTLS-in-STUN caches
// the handshake pacets and resends them using the pacing of ICE.
RTC_LOG(LS_INFO) << ToString() << ": configuring DTLS handshake timeout "
<< kDisabledHandshakeTimeoutMs << "ms for DTLS-in-STUN";
dtls_->SetInitialRetransmissionTimeout(kDisabledHandshakeTimeoutMs);
} else if (rtt_ms) {
// Limit the timeout to a reasonable range in case the ICE RTT takes // Limit the timeout to a reasonable range in case the ICE RTT takes
// extreme values. // extreme values.
int initial_timeout = std::max(kMinHandshakeTimeout, int initial_timeout_ms =
std::min(kMaxHandshakeTimeout, 2 * (*rtt))); std::max(kMinHandshakeTimeoutMs,
std::min(kMaxHandshakeTimeoutMs, 2 * (*rtt_ms)));
RTC_LOG(LS_INFO) << ToString() << ": configuring DTLS handshake timeout " RTC_LOG(LS_INFO) << ToString() << ": configuring DTLS handshake timeout "
<< initial_timeout << " based on ICE RTT " << *rtt; << initial_timeout_ms << "ms based on ICE RTT " << *rtt_ms;
dtls_->SetInitialRetransmissionTimeout(initial_timeout); dtls_->SetInitialRetransmissionTimeout(initial_timeout_ms);
} else { } else {
RTC_LOG(LS_INFO) RTC_LOG(LS_INFO)
<< ToString() << ToString()

View File

@ -237,7 +237,7 @@ class DtlsTransport : public DtlsTransportInternal {
void MaybeStartDtls(); void MaybeStartDtls();
bool HandleDtlsPacket(rtc::ArrayView<const uint8_t> payload); bool HandleDtlsPacket(rtc::ArrayView<const uint8_t> payload);
void OnDtlsHandshakeError(rtc::SSLHandshakeError error); void OnDtlsHandshakeError(rtc::SSLHandshakeError error);
void ConfigureHandshakeTimeout(); void ConfigureHandshakeTimeout(bool uses_dtls_in_stun);
void set_receiving(bool receiving); void set_receiving(bool receiving);
void set_writable(bool writable); void set_writable(bool writable);
@ -269,6 +269,8 @@ class DtlsTransport : public DtlsTransportInternal {
bool receiving_ = false; bool receiving_ = false;
bool writable_ = false; bool writable_ = false;
bool was_ever_connected_ = false;
webrtc::RtcEventLog* const event_log_; webrtc::RtcEventLog* const event_log_;
}; };

View File

@ -21,6 +21,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/functional/any_invocable.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "api/array_view.h" #include "api/array_view.h"
#include "api/crypto/crypto_options.h" #include "api/crypto/crypto_options.h"
@ -34,6 +36,8 @@
#include "p2p/dtls/dtls_utils.h" #include "p2p/dtls/dtls_utils.h"
#include "rtc_base/buffer.h" #include "rtc_base/buffer.h"
#include "rtc_base/byte_order.h" #include "rtc_base/byte_order.h"
#include "rtc_base/checks.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/fake_clock.h" #include "rtc_base/fake_clock.h"
#include "rtc_base/gunit.h" #include "rtc_base/gunit.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
@ -44,7 +48,6 @@
#include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/ssl_stream_adapter.h"
#include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/third_party/sigslot/sigslot.h"
#include "rtc_base/thread.h" #include "rtc_base/thread.h"
#include "test/field_trial.h"
#include "test/gtest.h" #include "test/gtest.h"
#define MAYBE_SKIP_TEST(feature) \ #define MAYBE_SKIP_TEST(feature) \

View File

@ -301,6 +301,8 @@ cricket::IceConfig ParseIceConfig(
ice_config.network_preference = config.network_preference; ice_config.network_preference = config.network_preference;
ice_config.stable_writable_connection_ping_interval = ice_config.stable_writable_connection_ping_interval =
config.stable_writable_connection_ping_interval_ms; config.stable_writable_connection_ping_interval_ms;
ice_config.dtls_handshake_in_stun =
false; // Filled in later based on field trial.
return ice_config; return ice_config;
} }
@ -916,7 +918,11 @@ JsepTransportController* PeerConnection::InitializeTransportController_n(
})); }));
}); });
transport_controller_->SetIceConfig(ParseIceConfig(configuration)); auto ice_config = ParseIceConfig(configuration);
ice_config.dtls_handshake_in_stun =
CanAttemptDtlsStunPiggybacking(configuration);
transport_controller_->SetIceConfig(ice_config);
return transport_controller_.get(); return transport_controller_.get();
} }
@ -1644,6 +1650,8 @@ RTCError PeerConnection::SetConfiguration(
modified_config.GetTurnPortPrunePolicy() != modified_config.GetTurnPortPrunePolicy() !=
configuration_.GetTurnPortPrunePolicy(); configuration_.GetTurnPortPrunePolicy();
cricket::IceConfig ice_config = ParseIceConfig(modified_config); cricket::IceConfig ice_config = ParseIceConfig(modified_config);
ice_config.dtls_handshake_in_stun =
CanAttemptDtlsStunPiggybacking(modified_config);
// Apply part of the configuration on the network thread. In theory this // Apply part of the configuration on the network thread. In theory this
// shouldn't fail. // shouldn't fail.
@ -3122,4 +3130,14 @@ PeerConnection::InitializeUnDemuxablePacketHandler() {
}; };
} }
bool PeerConnection::CanAttemptDtlsStunPiggybacking(
const RTCConfiguration& configuration) {
// Enable DTLS-in-STUN only if no certificates were passed those
// may be RSA certificates and this feature only works with small
// ECDSA certificates. Determining the type of the key is
// not trivially possible at this point.
return dtls_enabled_ && configuration.certificates.empty() &&
env_.field_trials().IsEnabled("WebRTC-IceHandshakeDtls");
}
} // namespace webrtc } // namespace webrtc

View File

@ -722,6 +722,8 @@ class PeerConnection : public PeerConnectionInternal,
PayloadTypePicker payload_type_picker_; PayloadTypePicker payload_type_picker_;
// This variable needs to be the last one in the class. // This variable needs to be the last one in the class.
rtc::WeakPtrFactory<PeerConnection> weak_factory_; rtc::WeakPtrFactory<PeerConnection> weak_factory_;
bool CanAttemptDtlsStunPiggybacking(const RTCConfiguration& configuration);
}; };
} // namespace webrtc } // namespace webrtc