proxying_websocket.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. // Copyright 2020 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "electron/shell/browser/net/proxying_websocket.h"
  5. #include <utility>
  6. #include "base/bind.h"
  7. #include "base/strings/string_util.h"
  8. #include "base/strings/stringprintf.h"
  9. #include "content/public/browser/browser_thread.h"
  10. #include "extensions/browser/extension_navigation_ui_data.h"
  11. #include "net/base/ip_endpoint.h"
  12. #include "net/http/http_util.h"
  13. #include "services/network/public/mojom/url_response_head.mojom.h"
  14. namespace electron {
  15. ProxyingWebSocket::ProxyingWebSocket(
  16. WebRequestAPI* web_request_api,
  17. WebSocketFactory factory,
  18. const network::ResourceRequest& request,
  19. mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
  20. handshake_client,
  21. bool has_extra_headers,
  22. int process_id,
  23. int render_frame_id,
  24. content::BrowserContext* browser_context,
  25. uint64_t* request_id_generator)
  26. : web_request_api_(web_request_api),
  27. request_(request),
  28. factory_(std::move(factory)),
  29. forwarding_handshake_client_(std::move(handshake_client)),
  30. request_headers_(request.headers),
  31. response_(network::mojom::URLResponseHead::New()),
  32. has_extra_headers_(has_extra_headers),
  33. info_(extensions::WebRequestInfoInitParams(
  34. ++(*request_id_generator),
  35. process_id,
  36. render_frame_id,
  37. nullptr,
  38. MSG_ROUTING_NONE,
  39. request,
  40. /*is_download=*/false,
  41. /*is_async=*/true,
  42. /*is_service_worker_script=*/false,
  43. /*navigation_id=*/absl::nullopt,
  44. /*ukm_source_id=*/ukm::kInvalidSourceIdObj)) {}
  45. ProxyingWebSocket::~ProxyingWebSocket() {
  46. if (on_before_send_headers_callback_) {
  47. std::move(on_before_send_headers_callback_)
  48. .Run(net::ERR_ABORTED, absl::nullopt);
  49. }
  50. if (on_headers_received_callback_) {
  51. std::move(on_headers_received_callback_)
  52. .Run(net::ERR_ABORTED, absl::nullopt, GURL());
  53. }
  54. }
  55. void ProxyingWebSocket::Start() {
  56. // If the header client will be used, we start the request immediately, and
  57. // OnBeforeSendHeaders and OnSendHeaders will be handled there. Otherwise,
  58. // send these events before the request starts.
  59. base::RepeatingCallback<void(int)> continuation;
  60. if (has_extra_headers_) {
  61. continuation = base::BindRepeating(
  62. &ProxyingWebSocket::ContinueToStartRequest, weak_factory_.GetWeakPtr());
  63. } else {
  64. continuation =
  65. base::BindRepeating(&ProxyingWebSocket::OnBeforeRequestComplete,
  66. weak_factory_.GetWeakPtr());
  67. }
  68. int result = web_request_api_->OnBeforeRequest(&info_, request_, continuation,
  69. &redirect_url_);
  70. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  71. OnError(result);
  72. return;
  73. }
  74. if (result == net::ERR_IO_PENDING) {
  75. return;
  76. }
  77. DCHECK_EQ(net::OK, result);
  78. continuation.Run(net::OK);
  79. }
  80. void ProxyingWebSocket::OnOpeningHandshakeStarted(
  81. network::mojom::WebSocketHandshakeRequestPtr request) {
  82. DCHECK(forwarding_handshake_client_);
  83. forwarding_handshake_client_->OnOpeningHandshakeStarted(std::move(request));
  84. }
  85. void ProxyingWebSocket::ContinueToHeadersReceived() {
  86. auto continuation =
  87. base::BindRepeating(&ProxyingWebSocket::OnHeadersReceivedComplete,
  88. weak_factory_.GetWeakPtr());
  89. info_.AddResponseInfoFromResourceResponse(*response_);
  90. int result = web_request_api_->OnHeadersReceived(
  91. &info_, request_, continuation, response_->headers.get(),
  92. &override_headers_, &redirect_url_);
  93. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  94. OnError(result);
  95. return;
  96. }
  97. PauseIncomingMethodCallProcessing();
  98. if (result == net::ERR_IO_PENDING)
  99. return;
  100. DCHECK_EQ(net::OK, result);
  101. OnHeadersReceivedComplete(net::OK);
  102. }
  103. void ProxyingWebSocket::OnFailure(const std::string& message,
  104. int32_t net_error,
  105. int32_t response_code) {}
  106. void ProxyingWebSocket::OnConnectionEstablished(
  107. mojo::PendingRemote<network::mojom::WebSocket> websocket,
  108. mojo::PendingReceiver<network::mojom::WebSocketClient> client_receiver,
  109. network::mojom::WebSocketHandshakeResponsePtr response,
  110. mojo::ScopedDataPipeConsumerHandle readable,
  111. mojo::ScopedDataPipeProducerHandle writable) {
  112. DCHECK(forwarding_handshake_client_);
  113. DCHECK(!is_done_);
  114. is_done_ = true;
  115. websocket_ = std::move(websocket);
  116. client_receiver_ = std::move(client_receiver);
  117. handshake_response_ = std::move(response);
  118. readable_ = std::move(readable);
  119. writable_ = std::move(writable);
  120. response_->remote_endpoint = handshake_response_->remote_endpoint;
  121. // response_->headers will be set in OnBeforeSendHeaders if
  122. // |receiver_as_header_client_| is set.
  123. if (receiver_as_header_client_.is_bound()) {
  124. ContinueToCompleted();
  125. return;
  126. }
  127. response_->headers =
  128. base::MakeRefCounted<net::HttpResponseHeaders>(base::StringPrintf(
  129. "HTTP/%d.%d %d %s", handshake_response_->http_version.major_value(),
  130. handshake_response_->http_version.minor_value(),
  131. handshake_response_->status_code,
  132. handshake_response_->status_text.c_str()));
  133. for (const auto& header : handshake_response_->headers)
  134. response_->headers->AddHeader(header->name, header->value);
  135. ContinueToHeadersReceived();
  136. }
  137. void ProxyingWebSocket::ContinueToCompleted() {
  138. DCHECK(forwarding_handshake_client_);
  139. DCHECK(is_done_);
  140. web_request_api_->OnCompleted(&info_, request_, net::ERR_WS_UPGRADE);
  141. forwarding_handshake_client_->OnConnectionEstablished(
  142. std::move(websocket_), std::move(client_receiver_),
  143. std::move(handshake_response_), std::move(readable_),
  144. std::move(writable_));
  145. // Deletes |this|.
  146. delete this;
  147. }
  148. void ProxyingWebSocket::OnAuthRequired(
  149. const net::AuthChallengeInfo& auth_info,
  150. const scoped_refptr<net::HttpResponseHeaders>& headers,
  151. const net::IPEndPoint& remote_endpoint,
  152. OnAuthRequiredCallback callback) {
  153. if (!callback) {
  154. OnError(net::ERR_FAILED);
  155. return;
  156. }
  157. response_->headers = headers;
  158. response_->remote_endpoint = remote_endpoint;
  159. auth_required_callback_ = std::move(callback);
  160. auto continuation =
  161. base::BindRepeating(&ProxyingWebSocket::OnHeadersReceivedCompleteForAuth,
  162. weak_factory_.GetWeakPtr(), auth_info);
  163. info_.AddResponseInfoFromResourceResponse(*response_);
  164. int result = web_request_api_->OnHeadersReceived(
  165. &info_, request_, continuation, response_->headers.get(),
  166. &override_headers_, &redirect_url_);
  167. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  168. OnError(result);
  169. return;
  170. }
  171. PauseIncomingMethodCallProcessing();
  172. if (result == net::ERR_IO_PENDING)
  173. return;
  174. DCHECK_EQ(net::OK, result);
  175. OnHeadersReceivedCompleteForAuth(auth_info, net::OK);
  176. }
  177. void ProxyingWebSocket::OnBeforeSendHeaders(
  178. const net::HttpRequestHeaders& headers,
  179. OnBeforeSendHeadersCallback callback) {
  180. DCHECK(receiver_as_header_client_.is_bound());
  181. request_headers_ = headers;
  182. on_before_send_headers_callback_ = std::move(callback);
  183. OnBeforeRequestComplete(net::OK);
  184. }
  185. void ProxyingWebSocket::OnHeadersReceived(const std::string& headers,
  186. const net::IPEndPoint& endpoint,
  187. OnHeadersReceivedCallback callback) {
  188. DCHECK(receiver_as_header_client_.is_bound());
  189. on_headers_received_callback_ = std::move(callback);
  190. response_->headers = base::MakeRefCounted<net::HttpResponseHeaders>(headers);
  191. ContinueToHeadersReceived();
  192. }
  193. void ProxyingWebSocket::StartProxying(
  194. WebRequestAPI* web_request_api,
  195. WebSocketFactory factory,
  196. const GURL& url,
  197. const net::SiteForCookies& site_for_cookies,
  198. const absl::optional<std::string>& user_agent,
  199. mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
  200. handshake_client,
  201. bool has_extra_headers,
  202. int process_id,
  203. int render_frame_id,
  204. const url::Origin& origin,
  205. content::BrowserContext* browser_context,
  206. uint64_t* request_id_generator) {
  207. DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
  208. network::ResourceRequest request;
  209. request.url = url;
  210. request.site_for_cookies = site_for_cookies;
  211. if (user_agent) {
  212. request.headers.SetHeader(net::HttpRequestHeaders::kUserAgent, *user_agent);
  213. }
  214. request.request_initiator = origin;
  215. auto* proxy = new ProxyingWebSocket(
  216. web_request_api, std::move(factory), request, std::move(handshake_client),
  217. has_extra_headers, process_id, render_frame_id, browser_context,
  218. request_id_generator);
  219. proxy->Start();
  220. }
  221. void ProxyingWebSocket::OnBeforeRequestComplete(int error_code) {
  222. DCHECK(receiver_as_header_client_.is_bound() ||
  223. !receiver_as_handshake_client_.is_bound());
  224. DCHECK(info_.url.SchemeIsWSOrWSS());
  225. if (error_code != net::OK) {
  226. OnError(error_code);
  227. return;
  228. }
  229. auto continuation =
  230. base::BindRepeating(&ProxyingWebSocket::OnBeforeSendHeadersComplete,
  231. weak_factory_.GetWeakPtr());
  232. info_.AddResponseInfoFromResourceResponse(*response_);
  233. int result = web_request_api_->OnBeforeSendHeaders(
  234. &info_, request_, continuation, &request_headers_);
  235. if (result == net::ERR_BLOCKED_BY_CLIENT) {
  236. OnError(result);
  237. return;
  238. }
  239. if (result == net::ERR_IO_PENDING)
  240. return;
  241. DCHECK_EQ(net::OK, result);
  242. OnBeforeSendHeadersComplete(std::set<std::string>(), std::set<std::string>(),
  243. net::OK);
  244. }
  245. void ProxyingWebSocket::OnBeforeSendHeadersComplete(
  246. const std::set<std::string>& removed_headers,
  247. const std::set<std::string>& set_headers,
  248. int error_code) {
  249. DCHECK(receiver_as_header_client_.is_bound() ||
  250. !receiver_as_handshake_client_.is_bound());
  251. if (error_code != net::OK) {
  252. OnError(error_code);
  253. return;
  254. }
  255. if (receiver_as_header_client_.is_bound()) {
  256. CHECK(on_before_send_headers_callback_);
  257. std::move(on_before_send_headers_callback_)
  258. .Run(error_code, request_headers_);
  259. }
  260. info_.AddResponseInfoFromResourceResponse(*response_);
  261. web_request_api_->OnSendHeaders(&info_, request_, request_headers_);
  262. if (!receiver_as_header_client_.is_bound())
  263. ContinueToStartRequest(net::OK);
  264. }
  265. void ProxyingWebSocket::ContinueToStartRequest(int error_code) {
  266. if (error_code != net::OK) {
  267. OnError(error_code);
  268. return;
  269. }
  270. base::flat_set<std::string> used_header_names;
  271. std::vector<network::mojom::HttpHeaderPtr> additional_headers;
  272. for (net::HttpRequestHeaders::Iterator it(request_headers_); it.GetNext();) {
  273. additional_headers.push_back(
  274. network::mojom::HttpHeader::New(it.name(), it.value()));
  275. used_header_names.insert(base::ToLowerASCII(it.name()));
  276. }
  277. for (const auto& header : additional_headers_) {
  278. if (!used_header_names.contains(base::ToLowerASCII(header->name))) {
  279. additional_headers.push_back(
  280. network::mojom::HttpHeader::New(header->name, header->value));
  281. }
  282. }
  283. mojo::PendingRemote<network::mojom::TrustedHeaderClient>
  284. trusted_header_client = mojo::NullRemote();
  285. if (has_extra_headers_) {
  286. trusted_header_client =
  287. receiver_as_header_client_.BindNewPipeAndPassRemote();
  288. }
  289. std::move(factory_).Run(
  290. info_.url, std::move(additional_headers),
  291. receiver_as_handshake_client_.BindNewPipeAndPassRemote(),
  292. receiver_as_auth_handler_.BindNewPipeAndPassRemote(),
  293. std::move(trusted_header_client));
  294. // Here we detect mojo connection errors on |receiver_as_handshake_client_|.
  295. // See also CreateWebSocket in
  296. // //network/services/public/mojom/network_context.mojom.
  297. receiver_as_handshake_client_.set_disconnect_with_reason_handler(
  298. base::BindOnce(&ProxyingWebSocket::OnMojoConnectionErrorWithCustomReason,
  299. base::Unretained(this)));
  300. forwarding_handshake_client_.set_disconnect_handler(base::BindOnce(
  301. &ProxyingWebSocket::OnMojoConnectionError, base::Unretained(this)));
  302. }
  303. void ProxyingWebSocket::OnHeadersReceivedComplete(int error_code) {
  304. if (error_code != net::OK) {
  305. OnError(error_code);
  306. return;
  307. }
  308. if (on_headers_received_callback_) {
  309. absl::optional<std::string> headers;
  310. if (override_headers_)
  311. headers = override_headers_->raw_headers();
  312. std::move(on_headers_received_callback_)
  313. .Run(net::OK, headers, absl::nullopt);
  314. }
  315. if (override_headers_) {
  316. response_->headers = override_headers_;
  317. override_headers_ = nullptr;
  318. }
  319. ResumeIncomingMethodCallProcessing();
  320. info_.AddResponseInfoFromResourceResponse(*response_);
  321. web_request_api_->OnResponseStarted(&info_, request_);
  322. if (!receiver_as_header_client_.is_bound())
  323. ContinueToCompleted();
  324. }
  325. void ProxyingWebSocket::OnAuthRequiredComplete(AuthRequiredResponse rv) {
  326. CHECK(auth_required_callback_);
  327. ResumeIncomingMethodCallProcessing();
  328. switch (rv) {
  329. case AuthRequiredResponse::kNoAction:
  330. case AuthRequiredResponse::kCancelAuth:
  331. std::move(auth_required_callback_).Run(absl::nullopt);
  332. break;
  333. case AuthRequiredResponse::kSetAuth:
  334. std::move(auth_required_callback_).Run(auth_credentials_);
  335. break;
  336. case AuthRequiredResponse::kIoPending:
  337. NOTREACHED();
  338. break;
  339. }
  340. }
  341. void ProxyingWebSocket::OnHeadersReceivedCompleteForAuth(
  342. const net::AuthChallengeInfo& auth_info,
  343. int rv) {
  344. if (rv != net::OK) {
  345. OnError(rv);
  346. return;
  347. }
  348. ResumeIncomingMethodCallProcessing();
  349. info_.AddResponseInfoFromResourceResponse(*response_);
  350. auto continuation = base::BindRepeating(
  351. &ProxyingWebSocket::OnAuthRequiredComplete, weak_factory_.GetWeakPtr());
  352. auto auth_rv = AuthRequiredResponse::kIoPending;
  353. PauseIncomingMethodCallProcessing();
  354. OnAuthRequiredComplete(auth_rv);
  355. }
  356. void ProxyingWebSocket::PauseIncomingMethodCallProcessing() {
  357. receiver_as_handshake_client_.Pause();
  358. receiver_as_auth_handler_.Pause();
  359. if (receiver_as_header_client_.is_bound())
  360. receiver_as_header_client_.Pause();
  361. }
  362. void ProxyingWebSocket::ResumeIncomingMethodCallProcessing() {
  363. receiver_as_handshake_client_.Resume();
  364. receiver_as_auth_handler_.Resume();
  365. if (receiver_as_header_client_.is_bound())
  366. receiver_as_header_client_.Resume();
  367. }
  368. void ProxyingWebSocket::OnError(int error_code) {
  369. if (!is_done_) {
  370. is_done_ = true;
  371. web_request_api_->OnErrorOccurred(&info_, request_, error_code);
  372. }
  373. // Deletes |this|.
  374. delete this;
  375. }
  376. void ProxyingWebSocket::OnMojoConnectionErrorWithCustomReason(
  377. uint32_t custom_reason,
  378. const std::string& description) {
  379. // Here we want to notify the custom reason to the client, which is why
  380. // we reset |forwarding_handshake_client_| manually.
  381. forwarding_handshake_client_.ResetWithReason(custom_reason, description);
  382. OnError(net::ERR_FAILED);
  383. // Deletes |this|.
  384. }
  385. void ProxyingWebSocket::OnMojoConnectionError() {
  386. OnError(net::ERR_FAILED);
  387. // Deletes |this|.
  388. }
  389. } // namespace electron