proxying_websocket.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. // Copyright 2020 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "electron/shell/browser/net/proxying_websocket.h"
  5. #include <utility>
  6. #include "base/bind.h"
  7. #include "base/strings/string_util.h"
  8. #include "base/strings/stringprintf.h"
  9. #include "components/keyed_service/content/browser_context_keyed_service_shutdown_notifier_factory.h"
  10. #include "content/public/browser/browser_thread.h"
  11. #include "extensions/browser/extension_navigation_ui_data.h"
  12. #include "net/base/ip_endpoint.h"
  13. #include "net/http/http_util.h"
  14. namespace electron {
  15. ProxyingWebSocket::ProxyingWebSocket(
  16. WebRequestAPI* web_request_api,
  17. WebSocketFactory factory,
  18. const network::ResourceRequest& request,
  19. mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
  20. handshake_client,
  21. bool has_extra_headers,
  22. int process_id,
  23. int render_frame_id,
  24. content::BrowserContext* browser_context,
  25. uint64_t* request_id_generator)
  26. : web_request_api_(web_request_api),
  27. request_(request),
  28. factory_(std::move(factory)),
  29. forwarding_handshake_client_(std::move(handshake_client)),
  30. request_headers_(request.headers),
  31. has_extra_headers_(has_extra_headers),
  32. info_(extensions::WebRequestInfoInitParams(
  33. ++(*request_id_generator),
  34. process_id,
  35. render_frame_id,
  36. nullptr,
  37. MSG_ROUTING_NONE,
  38. request,
  39. /*is_download=*/false,
  40. /*is_async=*/true,
  41. /*is_service_worker_script=*/false)) {}
  42. ProxyingWebSocket::~ProxyingWebSocket() {
  43. if (on_before_send_headers_callback_) {
  44. std::move(on_before_send_headers_callback_)
  45. .Run(net::ERR_ABORTED, base::nullopt);
  46. }
  47. if (on_headers_received_callback_) {
  48. std::move(on_headers_received_callback_)
  49. .Run(net::ERR_ABORTED, base::nullopt, GURL());
  50. }
  51. }
  52. void ProxyingWebSocket::Start() {
  53. // If the header client will be used, we start the request immediately, and
  54. // OnBeforeSendHeaders and OnSendHeaders will be handled there. Otherwise,
  55. // send these events before the request starts.
  56. base::RepeatingCallback<void(int)> continuation;
  57. if (has_extra_headers_) {
  58. continuation = base::BindRepeating(
  59. &ProxyingWebSocket::ContinueToStartRequest, weak_factory_.GetWeakPtr());
  60. } else {
  61. continuation =
  62. base::BindRepeating(&ProxyingWebSocket::OnBeforeRequestComplete,
  63. weak_factory_.GetWeakPtr());
  64. }
  65. int result = web_request_api_->OnBeforeRequest(&info_, request_, continuation,
  66. &redirect_url_);
  67. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  68. OnError(result);
  69. return;
  70. }
  71. if (result == net::ERR_IO_PENDING) {
  72. return;
  73. }
  74. DCHECK_EQ(net::OK, result);
  75. continuation.Run(net::OK);
  76. }
  77. void ProxyingWebSocket::OnOpeningHandshakeStarted(
  78. network::mojom::WebSocketHandshakeRequestPtr request) {
  79. DCHECK(forwarding_handshake_client_);
  80. forwarding_handshake_client_->OnOpeningHandshakeStarted(std::move(request));
  81. }
  82. void ProxyingWebSocket::OnResponseReceived(
  83. network::mojom::WebSocketHandshakeResponsePtr response) {
  84. DCHECK(forwarding_handshake_client_);
  85. // response_.headers will be set in OnBeforeSendHeaders if
  86. // |receiver_as_header_client_| is set.
  87. if (!receiver_as_header_client_.is_bound()) {
  88. response_.headers =
  89. base::MakeRefCounted<net::HttpResponseHeaders>(base::StringPrintf(
  90. "HTTP/%d.%d %d %s", response->http_version.major_value(),
  91. response->http_version.minor_value(), response->status_code,
  92. response->status_text.c_str()));
  93. for (const auto& header : response->headers)
  94. response_.headers->AddHeader(header->name + ": " + header->value);
  95. }
  96. response_.remote_endpoint = response->remote_endpoint;
  97. // TODO(yhirano): OnResponseReceived is called with the original
  98. // response headers. That means if OnHeadersReceived modified them the
  99. // renderer won't see that modification. This is the opposite of http(s)
  100. // requests.
  101. forwarding_handshake_client_->OnResponseReceived(std::move(response));
  102. if (!receiver_as_header_client_.is_bound() || response_.headers) {
  103. ContinueToHeadersReceived();
  104. } else {
  105. waiting_for_header_client_headers_received_ = true;
  106. }
  107. }
  108. void ProxyingWebSocket::ContinueToHeadersReceived() {
  109. auto continuation =
  110. base::BindRepeating(&ProxyingWebSocket::OnHeadersReceivedComplete,
  111. weak_factory_.GetWeakPtr());
  112. info_.AddResponseInfoFromResourceResponse(response_);
  113. int result = web_request_api_->OnHeadersReceived(
  114. &info_, request_, continuation, response_.headers.get(),
  115. &override_headers_, &redirect_url_);
  116. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  117. OnError(result);
  118. return;
  119. }
  120. PauseIncomingMethodCallProcessing();
  121. if (result == net::ERR_IO_PENDING)
  122. return;
  123. DCHECK_EQ(net::OK, result);
  124. OnHeadersReceivedComplete(net::OK);
  125. }
  126. void ProxyingWebSocket::OnConnectionEstablished(
  127. mojo::PendingRemote<network::mojom::WebSocket> websocket,
  128. mojo::PendingReceiver<network::mojom::WebSocketClient> client_receiver,
  129. const std::string& selected_protocol,
  130. const std::string& extensions,
  131. mojo::ScopedDataPipeConsumerHandle readable) {
  132. DCHECK(forwarding_handshake_client_);
  133. DCHECK(!is_done_);
  134. is_done_ = true;
  135. web_request_api_->OnCompleted(&info_, request_, net::ERR_WS_UPGRADE);
  136. forwarding_handshake_client_->OnConnectionEstablished(
  137. std::move(websocket), std::move(client_receiver), selected_protocol,
  138. extensions, std::move(readable));
  139. // Deletes |this|.
  140. delete this;
  141. }
  142. void ProxyingWebSocket::OnAuthRequired(
  143. const net::AuthChallengeInfo& auth_info,
  144. const scoped_refptr<net::HttpResponseHeaders>& headers,
  145. const net::IPEndPoint& remote_endpoint,
  146. OnAuthRequiredCallback callback) {
  147. if (!callback) {
  148. OnError(net::ERR_FAILED);
  149. return;
  150. }
  151. response_.headers = headers;
  152. response_.remote_endpoint = remote_endpoint;
  153. auth_required_callback_ = std::move(callback);
  154. auto continuation =
  155. base::BindRepeating(&ProxyingWebSocket::OnHeadersReceivedCompleteForAuth,
  156. weak_factory_.GetWeakPtr(), auth_info);
  157. info_.AddResponseInfoFromResourceResponse(response_);
  158. int result = web_request_api_->OnHeadersReceived(
  159. &info_, request_, continuation, response_.headers.get(),
  160. &override_headers_, &redirect_url_);
  161. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  162. OnError(result);
  163. return;
  164. }
  165. PauseIncomingMethodCallProcessing();
  166. if (result == net::ERR_IO_PENDING)
  167. return;
  168. DCHECK_EQ(net::OK, result);
  169. OnHeadersReceivedCompleteForAuth(auth_info, net::OK);
  170. }
  171. void ProxyingWebSocket::OnBeforeSendHeaders(
  172. const net::HttpRequestHeaders& headers,
  173. OnBeforeSendHeadersCallback callback) {
  174. DCHECK(receiver_as_header_client_.is_bound());
  175. request_headers_ = headers;
  176. on_before_send_headers_callback_ = std::move(callback);
  177. OnBeforeRequestComplete(net::OK);
  178. }
  179. void ProxyingWebSocket::OnHeadersReceived(const std::string& headers,
  180. OnHeadersReceivedCallback callback) {
  181. DCHECK(receiver_as_header_client_.is_bound());
  182. // Note: since there are different pipes used for WebSocketClient and
  183. // TrustedHeaderClient, there are no guarantees whether this or
  184. // OnResponseReceived are called first.
  185. on_headers_received_callback_ = std::move(callback);
  186. response_.headers = base::MakeRefCounted<net::HttpResponseHeaders>(headers);
  187. if (!waiting_for_header_client_headers_received_)
  188. return;
  189. waiting_for_header_client_headers_received_ = false;
  190. ContinueToHeadersReceived();
  191. }
  192. void ProxyingWebSocket::StartProxying(
  193. WebRequestAPI* web_request_api,
  194. WebSocketFactory factory,
  195. const GURL& url,
  196. const GURL& site_for_cookies,
  197. const base::Optional<std::string>& user_agent,
  198. mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
  199. handshake_client,
  200. bool has_extra_headers,
  201. int process_id,
  202. int render_frame_id,
  203. const url::Origin& origin,
  204. content::BrowserContext* browser_context,
  205. uint64_t* request_id_generator) {
  206. DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
  207. network::ResourceRequest request;
  208. request.url = url;
  209. request.site_for_cookies = site_for_cookies;
  210. if (user_agent) {
  211. request.headers.SetHeader(net::HttpRequestHeaders::kUserAgent, *user_agent);
  212. }
  213. request.request_initiator = origin;
  214. auto* proxy = new ProxyingWebSocket(
  215. web_request_api, std::move(factory), request, std::move(handshake_client),
  216. has_extra_headers, process_id, render_frame_id, browser_context,
  217. request_id_generator);
  218. proxy->Start();
  219. }
  220. void ProxyingWebSocket::OnBeforeRequestComplete(int error_code) {
  221. DCHECK(receiver_as_header_client_.is_bound() ||
  222. !receiver_as_handshake_client_.is_bound());
  223. DCHECK(info_.url.SchemeIsWSOrWSS());
  224. if (error_code != net::OK) {
  225. OnError(error_code);
  226. return;
  227. }
  228. auto continuation =
  229. base::BindRepeating(&ProxyingWebSocket::OnBeforeSendHeadersComplete,
  230. weak_factory_.GetWeakPtr());
  231. info_.AddResponseInfoFromResourceResponse(response_);
  232. int result = web_request_api_->OnBeforeSendHeaders(
  233. &info_, request_, continuation, &request_headers_);
  234. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  235. OnError(result);
  236. return;
  237. }
  238. if (result == net::ERR_IO_PENDING)
  239. return;
  240. DCHECK_EQ(net::OK, result);
  241. OnBeforeSendHeadersComplete(std::set<std::string>(), std::set<std::string>(),
  242. net::OK);
  243. }
  244. void ProxyingWebSocket::OnBeforeSendHeadersComplete(
  245. const std::set<std::string>& removed_headers,
  246. const std::set<std::string>& set_headers,
  247. int error_code) {
  248. DCHECK(receiver_as_header_client_.is_bound() ||
  249. !receiver_as_handshake_client_.is_bound());
  250. if (error_code != net::OK) {
  251. OnError(error_code);
  252. return;
  253. }
  254. if (receiver_as_header_client_.is_bound()) {
  255. CHECK(on_before_send_headers_callback_);
  256. std::move(on_before_send_headers_callback_)
  257. .Run(error_code, request_headers_);
  258. }
  259. info_.AddResponseInfoFromResourceResponse(response_);
  260. web_request_api_->OnSendHeaders(&info_, request_, request_headers_);
  261. if (!receiver_as_header_client_.is_bound())
  262. ContinueToStartRequest(net::OK);
  263. }
  264. void ProxyingWebSocket::ContinueToStartRequest(int error_code) {
  265. if (error_code != net::OK) {
  266. OnError(error_code);
  267. return;
  268. }
  269. base::flat_set<std::string> used_header_names;
  270. std::vector<network::mojom::HttpHeaderPtr> additional_headers;
  271. for (net::HttpRequestHeaders::Iterator it(request_headers_); it.GetNext();) {
  272. additional_headers.push_back(
  273. network::mojom::HttpHeader::New(it.name(), it.value()));
  274. used_header_names.insert(base::ToLowerASCII(it.name()));
  275. }
  276. for (const auto& header : additional_headers_) {
  277. if (!used_header_names.contains(base::ToLowerASCII(header->name))) {
  278. additional_headers.push_back(
  279. network::mojom::HttpHeader::New(header->name, header->value));
  280. }
  281. }
  282. mojo::PendingRemote<network::mojom::TrustedHeaderClient>
  283. trusted_header_client = mojo::NullRemote();
  284. if (has_extra_headers_) {
  285. trusted_header_client =
  286. receiver_as_header_client_.BindNewPipeAndPassRemote();
  287. }
  288. std::move(factory_).Run(
  289. info_.url, std::move(additional_headers),
  290. receiver_as_handshake_client_.BindNewPipeAndPassRemote(),
  291. receiver_as_auth_handler_.BindNewPipeAndPassRemote(),
  292. std::move(trusted_header_client));
  293. // Here we detect mojo connection errors on |receiver_as_handshake_client_|.
  294. // See also CreateWebSocket in
  295. // //network/services/public/mojom/network_context.mojom.
  296. receiver_as_handshake_client_.set_disconnect_with_reason_handler(
  297. base::BindOnce(&ProxyingWebSocket::OnMojoConnectionErrorWithCustomReason,
  298. base::Unretained(this)));
  299. forwarding_handshake_client_.set_disconnect_handler(base::BindOnce(
  300. &ProxyingWebSocket::OnMojoConnectionError, base::Unretained(this)));
  301. }
  302. void ProxyingWebSocket::OnHeadersReceivedComplete(int error_code) {
  303. if (error_code != net::OK) {
  304. OnError(error_code);
  305. return;
  306. }
  307. if (on_headers_received_callback_) {
  308. base::Optional<std::string> headers;
  309. if (override_headers_)
  310. headers = override_headers_->raw_headers();
  311. std::move(on_headers_received_callback_).Run(net::OK, headers, GURL());
  312. }
  313. if (override_headers_) {
  314. response_.headers = override_headers_;
  315. override_headers_ = nullptr;
  316. }
  317. ResumeIncomingMethodCallProcessing();
  318. info_.AddResponseInfoFromResourceResponse(response_);
  319. web_request_api_->OnResponseStarted(&info_, request_);
  320. }
  321. void ProxyingWebSocket::OnAuthRequiredComplete(
  322. net::NetworkDelegate::AuthRequiredResponse rv) {
  323. CHECK(auth_required_callback_);
  324. ResumeIncomingMethodCallProcessing();
  325. switch (rv) {
  326. case net::NetworkDelegate::AUTH_REQUIRED_RESPONSE_NO_ACTION:
  327. case net::NetworkDelegate::AUTH_REQUIRED_RESPONSE_CANCEL_AUTH:
  328. std::move(auth_required_callback_).Run(base::nullopt);
  329. break;
  330. case net::NetworkDelegate::AUTH_REQUIRED_RESPONSE_SET_AUTH:
  331. std::move(auth_required_callback_).Run(auth_credentials_);
  332. break;
  333. case net::NetworkDelegate::AUTH_REQUIRED_RESPONSE_IO_PENDING:
  334. NOTREACHED();
  335. break;
  336. }
  337. }
  338. void ProxyingWebSocket::OnHeadersReceivedCompleteForAuth(
  339. const net::AuthChallengeInfo& auth_info,
  340. int rv) {
  341. if (rv != net::OK) {
  342. OnError(rv);
  343. return;
  344. }
  345. ResumeIncomingMethodCallProcessing();
  346. info_.AddResponseInfoFromResourceResponse(response_);
  347. auto continuation = base::BindRepeating(
  348. &ProxyingWebSocket::OnAuthRequiredComplete, weak_factory_.GetWeakPtr());
  349. auto auth_rv = net::NetworkDelegate::AUTH_REQUIRED_RESPONSE_IO_PENDING;
  350. PauseIncomingMethodCallProcessing();
  351. OnAuthRequiredComplete(auth_rv);
  352. }
  353. void ProxyingWebSocket::PauseIncomingMethodCallProcessing() {
  354. receiver_as_handshake_client_.Pause();
  355. receiver_as_auth_handler_.Pause();
  356. if (receiver_as_header_client_.is_bound())
  357. receiver_as_header_client_.Pause();
  358. }
  359. void ProxyingWebSocket::ResumeIncomingMethodCallProcessing() {
  360. receiver_as_handshake_client_.Resume();
  361. receiver_as_auth_handler_.Resume();
  362. if (receiver_as_header_client_.is_bound())
  363. receiver_as_header_client_.Resume();
  364. }
  365. void ProxyingWebSocket::OnError(int error_code) {
  366. if (!is_done_) {
  367. is_done_ = true;
  368. web_request_api_->OnErrorOccurred(&info_, request_, error_code);
  369. }
  370. // Deletes |this|.
  371. delete this;
  372. }
  373. void ProxyingWebSocket::OnMojoConnectionErrorWithCustomReason(
  374. uint32_t custom_reason,
  375. const std::string& description) {
  376. // Here we want to nofiy the custom reason to the client, which is why
  377. // we reset |forwarding_handshake_client_| manually.
  378. forwarding_handshake_client_.ResetWithReason(custom_reason, description);
  379. OnError(net::ERR_FAILED);
  380. // Deletes |this|.
  381. }
  382. void ProxyingWebSocket::OnMojoConnectionError() {
  383. OnError(net::ERR_FAILED);
  384. // Deletes |this|.
  385. }
  386. } // namespace electron