proxying_websocket.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. // Copyright 2020 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "electron/shell/browser/net/proxying_websocket.h"
  5. #include <utility>
  6. #include "base/bind.h"
  7. #include "base/strings/string_util.h"
  8. #include "base/strings/stringprintf.h"
  9. #include "content/public/browser/browser_thread.h"
  10. #include "extensions/browser/extension_navigation_ui_data.h"
  11. #include "net/base/ip_endpoint.h"
  12. #include "net/http/http_util.h"
  13. #include "services/network/public/mojom/url_response_head.mojom.h"
  14. namespace electron {
  15. ProxyingWebSocket::ProxyingWebSocket(
  16. WebRequestAPI* web_request_api,
  17. WebSocketFactory factory,
  18. const network::ResourceRequest& request,
  19. mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
  20. handshake_client,
  21. bool has_extra_headers,
  22. int process_id,
  23. int render_frame_id,
  24. content::BrowserContext* browser_context,
  25. uint64_t* request_id_generator)
  26. : web_request_api_(web_request_api),
  27. request_(request),
  28. factory_(std::move(factory)),
  29. forwarding_handshake_client_(std::move(handshake_client)),
  30. request_headers_(request.headers),
  31. response_(network::mojom::URLResponseHead::New()),
  32. has_extra_headers_(has_extra_headers),
  33. info_(extensions::WebRequestInfoInitParams(
  34. ++(*request_id_generator),
  35. process_id,
  36. render_frame_id,
  37. nullptr,
  38. request,
  39. /*is_download=*/false,
  40. /*is_async=*/true,
  41. /*is_service_worker_script=*/false,
  42. /*navigation_id=*/absl::nullopt,
  43. /*ukm_source_id=*/ukm::kInvalidSourceIdObj)) {}
  44. ProxyingWebSocket::~ProxyingWebSocket() {
  45. if (on_before_send_headers_callback_) {
  46. std::move(on_before_send_headers_callback_)
  47. .Run(net::ERR_ABORTED, absl::nullopt);
  48. }
  49. if (on_headers_received_callback_) {
  50. std::move(on_headers_received_callback_)
  51. .Run(net::ERR_ABORTED, absl::nullopt, GURL());
  52. }
  53. }
  54. void ProxyingWebSocket::Start() {
  55. // If the header client will be used, we start the request immediately, and
  56. // OnBeforeSendHeaders and OnSendHeaders will be handled there. Otherwise,
  57. // send these events before the request starts.
  58. base::RepeatingCallback<void(int)> continuation;
  59. if (has_extra_headers_) {
  60. continuation = base::BindRepeating(
  61. &ProxyingWebSocket::ContinueToStartRequest, weak_factory_.GetWeakPtr());
  62. } else {
  63. continuation =
  64. base::BindRepeating(&ProxyingWebSocket::OnBeforeRequestComplete,
  65. weak_factory_.GetWeakPtr());
  66. }
  67. int result = web_request_api_->OnBeforeRequest(&info_, request_, continuation,
  68. &redirect_url_);
  69. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  70. OnError(result);
  71. return;
  72. }
  73. if (result == net::ERR_IO_PENDING) {
  74. return;
  75. }
  76. DCHECK_EQ(net::OK, result);
  77. continuation.Run(net::OK);
  78. }
  79. void ProxyingWebSocket::OnOpeningHandshakeStarted(
  80. network::mojom::WebSocketHandshakeRequestPtr request) {
  81. DCHECK(forwarding_handshake_client_);
  82. forwarding_handshake_client_->OnOpeningHandshakeStarted(std::move(request));
  83. }
  84. void ProxyingWebSocket::ContinueToHeadersReceived() {
  85. auto continuation =
  86. base::BindRepeating(&ProxyingWebSocket::OnHeadersReceivedComplete,
  87. weak_factory_.GetWeakPtr());
  88. info_.AddResponseInfoFromResourceResponse(*response_);
  89. int result = web_request_api_->OnHeadersReceived(
  90. &info_, request_, continuation, response_->headers.get(),
  91. &override_headers_, &redirect_url_);
  92. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  93. OnError(result);
  94. return;
  95. }
  96. PauseIncomingMethodCallProcessing();
  97. if (result == net::ERR_IO_PENDING)
  98. return;
  99. DCHECK_EQ(net::OK, result);
  100. OnHeadersReceivedComplete(net::OK);
  101. }
  102. void ProxyingWebSocket::OnFailure(const std::string& message,
  103. int32_t net_error,
  104. int32_t response_code) {}
  105. void ProxyingWebSocket::OnConnectionEstablished(
  106. mojo::PendingRemote<network::mojom::WebSocket> websocket,
  107. mojo::PendingReceiver<network::mojom::WebSocketClient> client_receiver,
  108. network::mojom::WebSocketHandshakeResponsePtr response,
  109. mojo::ScopedDataPipeConsumerHandle readable,
  110. mojo::ScopedDataPipeProducerHandle writable) {
  111. DCHECK(forwarding_handshake_client_);
  112. DCHECK(!is_done_);
  113. is_done_ = true;
  114. websocket_ = std::move(websocket);
  115. client_receiver_ = std::move(client_receiver);
  116. handshake_response_ = std::move(response);
  117. readable_ = std::move(readable);
  118. writable_ = std::move(writable);
  119. response_->remote_endpoint = handshake_response_->remote_endpoint;
  120. // response_->headers will be set in OnBeforeSendHeaders if
  121. // |receiver_as_header_client_| is set.
  122. if (receiver_as_header_client_.is_bound()) {
  123. ContinueToCompleted();
  124. return;
  125. }
  126. response_->headers =
  127. base::MakeRefCounted<net::HttpResponseHeaders>(base::StringPrintf(
  128. "HTTP/%d.%d %d %s", handshake_response_->http_version.major_value(),
  129. handshake_response_->http_version.minor_value(),
  130. handshake_response_->status_code,
  131. handshake_response_->status_text.c_str()));
  132. for (const auto& header : handshake_response_->headers)
  133. response_->headers->AddHeader(header->name, header->value);
  134. ContinueToHeadersReceived();
  135. }
  136. void ProxyingWebSocket::ContinueToCompleted() {
  137. DCHECK(forwarding_handshake_client_);
  138. DCHECK(is_done_);
  139. web_request_api_->OnCompleted(&info_, request_, net::ERR_WS_UPGRADE);
  140. forwarding_handshake_client_->OnConnectionEstablished(
  141. std::move(websocket_), std::move(client_receiver_),
  142. std::move(handshake_response_), std::move(readable_),
  143. std::move(writable_));
  144. // Deletes |this|.
  145. delete this;
  146. }
  147. void ProxyingWebSocket::OnAuthRequired(
  148. const net::AuthChallengeInfo& auth_info,
  149. const scoped_refptr<net::HttpResponseHeaders>& headers,
  150. const net::IPEndPoint& remote_endpoint,
  151. OnAuthRequiredCallback callback) {
  152. if (!callback) {
  153. OnError(net::ERR_FAILED);
  154. return;
  155. }
  156. response_->headers = headers;
  157. response_->remote_endpoint = remote_endpoint;
  158. auth_required_callback_ = std::move(callback);
  159. auto continuation =
  160. base::BindRepeating(&ProxyingWebSocket::OnHeadersReceivedCompleteForAuth,
  161. weak_factory_.GetWeakPtr(), auth_info);
  162. info_.AddResponseInfoFromResourceResponse(*response_);
  163. int result = web_request_api_->OnHeadersReceived(
  164. &info_, request_, continuation, response_->headers.get(),
  165. &override_headers_, &redirect_url_);
  166. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  167. OnError(result);
  168. return;
  169. }
  170. PauseIncomingMethodCallProcessing();
  171. if (result == net::ERR_IO_PENDING)
  172. return;
  173. DCHECK_EQ(net::OK, result);
  174. OnHeadersReceivedCompleteForAuth(auth_info, net::OK);
  175. }
  176. void ProxyingWebSocket::OnBeforeSendHeaders(
  177. const net::HttpRequestHeaders& headers,
  178. OnBeforeSendHeadersCallback callback) {
  179. DCHECK(receiver_as_header_client_.is_bound());
  180. request_headers_ = headers;
  181. on_before_send_headers_callback_ = std::move(callback);
  182. OnBeforeRequestComplete(net::OK);
  183. }
  184. void ProxyingWebSocket::OnHeadersReceived(const std::string& headers,
  185. const net::IPEndPoint& endpoint,
  186. OnHeadersReceivedCallback callback) {
  187. DCHECK(receiver_as_header_client_.is_bound());
  188. on_headers_received_callback_ = std::move(callback);
  189. response_->headers = base::MakeRefCounted<net::HttpResponseHeaders>(headers);
  190. ContinueToHeadersReceived();
  191. }
  192. void ProxyingWebSocket::StartProxying(
  193. WebRequestAPI* web_request_api,
  194. WebSocketFactory factory,
  195. const GURL& url,
  196. const net::SiteForCookies& site_for_cookies,
  197. const absl::optional<std::string>& user_agent,
  198. mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
  199. handshake_client,
  200. bool has_extra_headers,
  201. int process_id,
  202. int render_frame_id,
  203. const url::Origin& origin,
  204. content::BrowserContext* browser_context,
  205. uint64_t* request_id_generator) {
  206. DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
  207. network::ResourceRequest request;
  208. request.url = url;
  209. request.site_for_cookies = site_for_cookies;
  210. if (user_agent) {
  211. request.headers.SetHeader(net::HttpRequestHeaders::kUserAgent, *user_agent);
  212. }
  213. request.request_initiator = origin;
  214. auto* proxy = new ProxyingWebSocket(
  215. web_request_api, std::move(factory), request, std::move(handshake_client),
  216. has_extra_headers, process_id, render_frame_id, browser_context,
  217. request_id_generator);
  218. proxy->Start();
  219. }
  220. void ProxyingWebSocket::OnBeforeRequestComplete(int error_code) {
  221. DCHECK(receiver_as_header_client_.is_bound() ||
  222. !receiver_as_handshake_client_.is_bound());
  223. DCHECK(info_.url.SchemeIsWSOrWSS());
  224. if (error_code != net::OK) {
  225. OnError(error_code);
  226. return;
  227. }
  228. auto continuation =
  229. base::BindRepeating(&ProxyingWebSocket::OnBeforeSendHeadersComplete,
  230. weak_factory_.GetWeakPtr());
  231. info_.AddResponseInfoFromResourceResponse(*response_);
  232. int result = web_request_api_->OnBeforeSendHeaders(
  233. &info_, request_, continuation, &request_headers_);
  234. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  235. OnError(result);
  236. return;
  237. }
  238. if (result == net::ERR_IO_PENDING)
  239. return;
  240. DCHECK_EQ(net::OK, result);
  241. OnBeforeSendHeadersComplete(std::set<std::string>(), std::set<std::string>(),
  242. net::OK);
  243. }
  244. void ProxyingWebSocket::OnBeforeSendHeadersComplete(
  245. const std::set<std::string>& removed_headers,
  246. const std::set<std::string>& set_headers,
  247. int error_code) {
  248. DCHECK(receiver_as_header_client_.is_bound() ||
  249. !receiver_as_handshake_client_.is_bound());
  250. if (error_code != net::OK) {
  251. OnError(error_code);
  252. return;
  253. }
  254. if (receiver_as_header_client_.is_bound()) {
  255. CHECK(on_before_send_headers_callback_);
  256. std::move(on_before_send_headers_callback_)
  257. .Run(error_code, request_headers_);
  258. }
  259. info_.AddResponseInfoFromResourceResponse(*response_);
  260. web_request_api_->OnSendHeaders(&info_, request_, request_headers_);
  261. if (!receiver_as_header_client_.is_bound())
  262. ContinueToStartRequest(net::OK);
  263. }
  264. void ProxyingWebSocket::ContinueToStartRequest(int error_code) {
  265. if (error_code != net::OK) {
  266. OnError(error_code);
  267. return;
  268. }
  269. base::flat_set<std::string> used_header_names;
  270. std::vector<network::mojom::HttpHeaderPtr> additional_headers;
  271. for (net::HttpRequestHeaders::Iterator it(request_headers_); it.GetNext();) {
  272. additional_headers.push_back(
  273. network::mojom::HttpHeader::New(it.name(), it.value()));
  274. used_header_names.insert(base::ToLowerASCII(it.name()));
  275. }
  276. for (const auto& header : additional_headers_) {
  277. if (!used_header_names.contains(base::ToLowerASCII(header->name))) {
  278. additional_headers.push_back(
  279. network::mojom::HttpHeader::New(header->name, header->value));
  280. }
  281. }
  282. mojo::PendingRemote<network::mojom::TrustedHeaderClient>
  283. trusted_header_client = mojo::NullRemote();
  284. if (has_extra_headers_) {
  285. trusted_header_client =
  286. receiver_as_header_client_.BindNewPipeAndPassRemote();
  287. }
  288. std::move(factory_).Run(
  289. info_.url, std::move(additional_headers),
  290. receiver_as_handshake_client_.BindNewPipeAndPassRemote(),
  291. receiver_as_auth_handler_.BindNewPipeAndPassRemote(),
  292. std::move(trusted_header_client));
  293. // Here we detect mojo connection errors on |receiver_as_handshake_client_|.
  294. // See also CreateWebSocket in
  295. // //network/services/public/mojom/network_context.mojom.
  296. receiver_as_handshake_client_.set_disconnect_with_reason_handler(
  297. base::BindOnce(&ProxyingWebSocket::OnMojoConnectionErrorWithCustomReason,
  298. base::Unretained(this)));
  299. forwarding_handshake_client_.set_disconnect_handler(base::BindOnce(
  300. &ProxyingWebSocket::OnMojoConnectionError, base::Unretained(this)));
  301. }
  302. void ProxyingWebSocket::OnHeadersReceivedComplete(int error_code) {
  303. if (error_code != net::OK) {
  304. OnError(error_code);
  305. return;
  306. }
  307. if (on_headers_received_callback_) {
  308. absl::optional<std::string> headers;
  309. if (override_headers_)
  310. headers = override_headers_->raw_headers();
  311. std::move(on_headers_received_callback_)
  312. .Run(net::OK, headers, absl::nullopt);
  313. }
  314. if (override_headers_) {
  315. response_->headers = override_headers_;
  316. override_headers_ = nullptr;
  317. }
  318. ResumeIncomingMethodCallProcessing();
  319. info_.AddResponseInfoFromResourceResponse(*response_);
  320. web_request_api_->OnResponseStarted(&info_, request_);
  321. if (!receiver_as_header_client_.is_bound())
  322. ContinueToCompleted();
  323. }
  324. void ProxyingWebSocket::OnAuthRequiredComplete(AuthRequiredResponse rv) {
  325. CHECK(auth_required_callback_);
  326. ResumeIncomingMethodCallProcessing();
  327. switch (rv) {
  328. case AuthRequiredResponse::kNoAction:
  329. case AuthRequiredResponse::kCancelAuth:
  330. std::move(auth_required_callback_).Run(absl::nullopt);
  331. break;
  332. case AuthRequiredResponse::kSetAuth:
  333. std::move(auth_required_callback_).Run(auth_credentials_);
  334. break;
  335. case AuthRequiredResponse::kIoPending:
  336. NOTREACHED();
  337. break;
  338. }
  339. }
  340. void ProxyingWebSocket::OnHeadersReceivedCompleteForAuth(
  341. const net::AuthChallengeInfo& auth_info,
  342. int rv) {
  343. if (rv != net::OK) {
  344. OnError(rv);
  345. return;
  346. }
  347. ResumeIncomingMethodCallProcessing();
  348. info_.AddResponseInfoFromResourceResponse(*response_);
  349. auto continuation = base::BindRepeating(
  350. &ProxyingWebSocket::OnAuthRequiredComplete, weak_factory_.GetWeakPtr());
  351. auto auth_rv = AuthRequiredResponse::kIoPending;
  352. PauseIncomingMethodCallProcessing();
  353. OnAuthRequiredComplete(auth_rv);
  354. }
  355. void ProxyingWebSocket::PauseIncomingMethodCallProcessing() {
  356. receiver_as_handshake_client_.Pause();
  357. receiver_as_auth_handler_.Pause();
  358. if (receiver_as_header_client_.is_bound())
  359. receiver_as_header_client_.Pause();
  360. }
  361. void ProxyingWebSocket::ResumeIncomingMethodCallProcessing() {
  362. receiver_as_handshake_client_.Resume();
  363. receiver_as_auth_handler_.Resume();
  364. if (receiver_as_header_client_.is_bound())
  365. receiver_as_header_client_.Resume();
  366. }
  367. void ProxyingWebSocket::OnError(int error_code) {
  368. if (!is_done_) {
  369. is_done_ = true;
  370. web_request_api_->OnErrorOccurred(&info_, request_, error_code);
  371. }
  372. // Deletes |this|.
  373. delete this;
  374. }
  375. void ProxyingWebSocket::OnMojoConnectionErrorWithCustomReason(
  376. uint32_t custom_reason,
  377. const std::string& description) {
  378. // Here we want to notify the custom reason to the client, which is why
  379. // we reset |forwarding_handshake_client_| manually.
  380. forwarding_handshake_client_.ResetWithReason(custom_reason, description);
  381. OnError(net::ERR_FAILED);
  382. // Deletes |this|.
  383. }
  384. void ProxyingWebSocket::OnMojoConnectionError() {
  385. OnError(net::ERR_FAILED);
  386. // Deletes |this|.
  387. }
  388. } // namespace electron