Browse Source

fix: make webRequest work with WebSocket (7-1-x) (#22141)

* fix: make webRequest work with WebSocket

* test: install modules for spec-main
Cheng Zhao 5 years ago
parent
commit
40a212e132

+ 2 - 0
docs/api/web-request.md

@@ -146,6 +146,7 @@ response are visible by the time this listener is fired.
     * `timestamp` Double
     * `statusLine` String
     * `statusCode` Integer
+    * `requestHeaders` Record<string, string>
     * `responseHeaders` Record<string, string[]> (optional)
   * `callback` Function
     * `headersReceivedResponse` Object
@@ -228,6 +229,7 @@ redirect is about to occur.
     * `fromCache` Boolean
     * `statusCode` Integer
     * `statusLine` String
+    * `error` String
 
 The `listener` will be called with `listener(details)` when a request is
 completed.

+ 3 - 0
filenames.gni

@@ -221,6 +221,8 @@ filenames = {
     "shell/browser/net/cert_verifier_client.h",
     "shell/browser/net/proxying_url_loader_factory.cc",
     "shell/browser/net/proxying_url_loader_factory.h",
+    "shell/browser/net/proxying_websocket.cc",
+    "shell/browser/net/proxying_websocket.h",
     "shell/browser/net/network_context_service_factory.cc",
     "shell/browser/net/network_context_service_factory.h",
     "shell/browser/net/network_context_service.cc",
@@ -233,6 +235,7 @@ filenames = {
     "shell/browser/net/system_network_context_manager.h",
     "shell/browser/net/url_pipe_loader.cc",
     "shell/browser/net/url_pipe_loader.h",
+    "shell/browser/net/web_request_api_interface.h",
     "shell/browser/notifications/linux/libnotify_notification.cc",
     "shell/browser/notifications/linux/libnotify_notification.h",
     "shell/browser/notifications/linux/notification_presenter_linux.cc",

+ 8 - 4
script/spec-runner.js

@@ -45,7 +45,8 @@ async function main () {
       (lastSpecInstallHash !== currentSpecInstallHash)
 
   if (somethingChanged) {
-    await installSpecModules()
+    await installSpecModules(path.resolve(__dirname, '..', 'spec'))
+    await installSpecModules(path.resolve(__dirname, '..', 'spec-main'))
     await getSpecHash().then(saveSpecHash)
   }
 
@@ -140,7 +141,7 @@ async function runMainProcessElectronTests () {
   }
 }
 
-async function installSpecModules () {
+async function installSpecModules (dir) {
   const nodeDir = path.resolve(BASE, `out/${utils.OUT_DIR}/gen/node_headers`)
   const env = Object.assign({}, process.env, {
     npm_config_nodedir: nodeDir,
@@ -148,11 +149,12 @@ async function installSpecModules () {
   })
   const { status } = childProcess.spawnSync(NPX_CMD, [`yarn@${YARN_VERSION}`, 'install', '--frozen-lockfile'], {
     env,
-    cwd: path.resolve(__dirname, '../spec'),
+    cwd: dir,
     stdio: 'inherit'
   })
   if (status !== 0 && !process.env.IGNORE_YARN_INSTALL_ERROR) {
-    throw new Error('Failed to yarn install in the spec folder')
+    console.log(`Failed to yarn install in '${dir}'`)
+    process.exit(1)
   }
 }
 
@@ -161,7 +163,9 @@ function getSpecHash () {
     (async () => {
       const hasher = crypto.createHash('SHA256')
       hasher.update(fs.readFileSync(path.resolve(__dirname, '../spec/package.json')))
+      hasher.update(fs.readFileSync(path.resolve(__dirname, '../spec-main/package.json')))
       hasher.update(fs.readFileSync(path.resolve(__dirname, '../spec/yarn.lock')))
+      hasher.update(fs.readFileSync(path.resolve(__dirname, '../spec-main/yarn.lock')))
       return hasher.digest('hex')
     })(),
     (async () => {

+ 5 - 5
shell/browser/api/atom_api_web_request_ns.h

@@ -15,7 +15,7 @@
 #include "gin/wrappable.h"
 #include "native_mate/dictionary.h"
 #include "native_mate/handle.h"
-#include "shell/browser/net/proxying_url_loader_factory.h"
+#include "shell/browser/net/web_request_api_interface.h"
 
 namespace content {
 class BrowserContext;
@@ -52,10 +52,6 @@ class WebRequestNS : public gin::Wrappable<WebRequestNS>, public WebRequestAPI {
       v8::Isolate* isolate) override;
   const char* GetTypeName() override;
 
- private:
-  WebRequestNS(v8::Isolate* isolate, content::BrowserContext* browser_context);
-  ~WebRequestNS() override;
-
   // WebRequestAPI:
   bool HasListener() const override;
   int OnBeforeRequest(extensions::WebRequestInfo* info,
@@ -89,6 +85,10 @@ class WebRequestNS : public gin::Wrappable<WebRequestNS>, public WebRequestAPI {
                    int net_error) override;
   void OnRequestWillBeDestroyed(extensions::WebRequestInfo* info) override;
 
+ private:
+  WebRequestNS(v8::Isolate* isolate, content::BrowserContext* browser_context);
+  ~WebRequestNS() override;
+
   enum SimpleEvent {
     kOnSendHeaders,
     kOnBeforeRedirect,

+ 38 - 1
shell/browser/atom_browser_client.cc

@@ -67,6 +67,7 @@
 #include "shell/browser/net/network_context_service.h"
 #include "shell/browser/net/network_context_service_factory.h"
 #include "shell/browser/net/proxying_url_loader_factory.h"
+#include "shell/browser/net/proxying_websocket.h"
 #include "shell/browser/net/system_network_context_manager.h"
 #include "shell/browser/notifications/notification_presenter.h"
 #include "shell/browser/notifications/platform_notification_service.h"
@@ -982,6 +983,42 @@ void AtomBrowserClient::RegisterNonNetworkSubresourceURLLoaderFactories(
   }
 }
 
+bool AtomBrowserClient::WillInterceptWebSocket(
+    content::RenderFrameHost* frame) {
+  if (!frame)
+    return false;
+
+  v8::Isolate* isolate = v8::Isolate::GetCurrent();
+  auto* browser_context = frame->GetProcess()->GetBrowserContext();
+  auto web_request = api::WebRequestNS::FromOrCreate(isolate, browser_context);
+
+  // NOTE: Some unit test environments do not initialize
+  // BrowserContextKeyedAPI factories for e.g. WebRequest.
+  if (!web_request.get())
+    return false;
+
+  return web_request->HasListener();
+}
+
+void AtomBrowserClient::CreateWebSocket(
+    content::RenderFrameHost* frame,
+    WebSocketFactory factory,
+    const GURL& url,
+    const GURL& site_for_cookies,
+    const base::Optional<std::string>& user_agent,
+    mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
+        handshake_client) {
+  v8::Isolate* isolate = v8::Isolate::GetCurrent();
+  auto* browser_context = frame->GetProcess()->GetBrowserContext();
+  auto web_request = api::WebRequestNS::FromOrCreate(isolate, browser_context);
+  DCHECK(web_request.get());
+  ProxyingWebSocket::StartProxying(
+      web_request.get(), std::move(factory), url, site_for_cookies, user_agent,
+      std::move(handshake_client), true, frame->GetProcess()->GetID(),
+      frame->GetRoutingID(), frame->GetLastCommittedOrigin(), browser_context,
+      &next_id_);
+}
+
 bool AtomBrowserClient::WillCreateURLLoaderFactory(
     content::BrowserContext* browser_context,
     content::RenderFrameHost* frame_host,
@@ -1010,7 +1047,7 @@ bool AtomBrowserClient::WillCreateURLLoaderFactory(
 
   new ProxyingURLLoaderFactory(
       web_request.get(), protocol->intercept_handlers(), render_process_id,
-      std::move(proxied_receiver), std::move(target_factory_info),
+      &next_id_, std::move(proxied_receiver), std::move(target_factory_info),
       std::move(header_client_receiver), type);
 
   if (bypass_redirect_checks)

+ 13 - 0
shell/browser/atom_browser_client.h

@@ -170,6 +170,15 @@ class AtomBrowserClient : public content::ContentBrowserClient,
       int render_process_id,
       int render_frame_id,
       NonNetworkURLLoaderFactoryMap* factories) override;
+  void CreateWebSocket(
+      content::RenderFrameHost* frame,
+      WebSocketFactory factory,
+      const GURL& url,
+      const GURL& site_for_cookies,
+      const base::Optional<std::string>& user_agent,
+      mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
+          handshake_client) override;
+  bool WillInterceptWebSocket(content::RenderFrameHost*) override;
   bool WillCreateURLLoaderFactory(
       content::BrowserContext* browser_context,
       content::RenderFrameHost* frame,
@@ -274,6 +283,10 @@ class AtomBrowserClient : public content::ContentBrowserClient,
 
   bool disable_process_restart_tricks_ = false;
 
+  // Simple shared ID generator, used by ProxyingURLLoaderFactory and
+  // ProxyingWebSocket classes.
+  uint64_t next_id_ = 0;
+
   DISALLOW_COPY_AND_ASSIGN(AtomBrowserClient);
 };
 

+ 3 - 8
shell/browser/net/proxying_url_loader_factory.cc

@@ -18,13 +18,6 @@
 #include "shell/common/options_switches.h"
 
 namespace electron {
-
-namespace {
-
-int64_t g_request_id = 0;
-
-}  // namespace
-
 ProxyingURLLoaderFactory::InProgressRequest::FollowRedirectParams::
     FollowRedirectParams() = default;
 ProxyingURLLoaderFactory::InProgressRequest::FollowRedirectParams::
@@ -668,6 +661,7 @@ ProxyingURLLoaderFactory::ProxyingURLLoaderFactory(
     WebRequestAPI* web_request_api,
     const HandlersMap& intercepted_handlers,
     int render_process_id,
+    uint64_t* request_id_generator,
     network::mojom::URLLoaderFactoryRequest loader_request,
     network::mojom::URLLoaderFactoryPtrInfo target_factory_info,
     mojo::PendingReceiver<network::mojom::TrustedURLLoaderHeaderClient>
@@ -676,6 +670,7 @@ ProxyingURLLoaderFactory::ProxyingURLLoaderFactory(
     : web_request_api_(web_request_api),
       intercepted_handlers_(intercepted_handlers),
       render_process_id_(render_process_id),
+      request_id_generator_(request_id_generator),
       loader_factory_type_(loader_factory_type) {
   target_factory_.Bind(std::move(target_factory_info));
   target_factory_.set_connection_error_handler(base::BindOnce(
@@ -751,7 +746,7 @@ void ProxyingURLLoaderFactory::CreateLoaderAndStart(
   // per-BrowserContext so extensions can make sense of it.  Note that
   // |network_service_request_id_| by contrast is not necessarily unique, so we
   // don't use it for identity here.
-  const uint64_t web_request_id = ++g_request_id;
+  const uint64_t web_request_id = ++(*request_id_generator_);
 
   // Notes: Chromium assumes that requests with zero-ID would never use the
   // "extraHeaders" code path, however in Electron requests started from

+ 3 - 43
shell/browser/net/proxying_url_loader_factory.h

@@ -20,52 +20,10 @@
 #include "services/network/public/mojom/network_context.mojom.h"
 #include "services/network/public/mojom/url_loader.mojom.h"
 #include "shell/browser/net/atom_url_loader_factory.h"
+#include "shell/browser/net/web_request_api_interface.h"
 
 namespace electron {
 
-// Defines the interface for WebRequest API, implemented by api::WebRequestNS.
-class WebRequestAPI {
- public:
-  virtual ~WebRequestAPI() {}
-
-  using BeforeSendHeadersCallback =
-      base::OnceCallback<void(const std::set<std::string>& removed_headers,
-                              const std::set<std::string>& set_headers,
-                              int error_code)>;
-
-  virtual bool HasListener() const = 0;
-  virtual int OnBeforeRequest(extensions::WebRequestInfo* info,
-                              const network::ResourceRequest& request,
-                              net::CompletionOnceCallback callback,
-                              GURL* new_url) = 0;
-  virtual int OnBeforeSendHeaders(extensions::WebRequestInfo* info,
-                                  const network::ResourceRequest& request,
-                                  BeforeSendHeadersCallback callback,
-                                  net::HttpRequestHeaders* headers) = 0;
-  virtual int OnHeadersReceived(
-      extensions::WebRequestInfo* info,
-      const network::ResourceRequest& request,
-      net::CompletionOnceCallback callback,
-      const net::HttpResponseHeaders* original_response_headers,
-      scoped_refptr<net::HttpResponseHeaders>* override_response_headers,
-      GURL* allowed_unsafe_redirect_url) = 0;
-  virtual void OnSendHeaders(extensions::WebRequestInfo* info,
-                             const network::ResourceRequest& request,
-                             const net::HttpRequestHeaders& headers) = 0;
-  virtual void OnBeforeRedirect(extensions::WebRequestInfo* info,
-                                const network::ResourceRequest& request,
-                                const GURL& new_location) = 0;
-  virtual void OnResponseStarted(extensions::WebRequestInfo* info,
-                                 const network::ResourceRequest& request) = 0;
-  virtual void OnErrorOccurred(extensions::WebRequestInfo* info,
-                               const network::ResourceRequest& request,
-                               int net_error) = 0;
-  virtual void OnCompleted(extensions::WebRequestInfo* info,
-                           const network::ResourceRequest& request,
-                           int net_error) = 0;
-  virtual void OnRequestWillBeDestroyed(extensions::WebRequestInfo* info) = 0;
-};
-
 // This class is responsible for following tasks when NetworkService is enabled:
 // 1. handling intercepted protocols;
 // 2. implementing webRequest module;
@@ -203,6 +161,7 @@ class ProxyingURLLoaderFactory
       WebRequestAPI* web_request_api,
       const HandlersMap& intercepted_handlers,
       int render_process_id,
+      uint64_t* request_id_generator,
       network::mojom::URLLoaderFactoryRequest loader_request,
       network::mojom::URLLoaderFactoryPtrInfo target_factory_info,
       mojo::PendingReceiver<network::mojom::TrustedURLLoaderHeaderClient>
@@ -254,6 +213,7 @@ class ProxyingURLLoaderFactory
   const int render_process_id_;
   mojo::BindingSet<network::mojom::URLLoaderFactory> proxy_bindings_;
   network::mojom::URLLoaderFactoryPtr target_factory_;
+  uint64_t* request_id_generator_;  // managed by AtomBrowserClient
   mojo::Receiver<network::mojom::TrustedURLLoaderHeaderClient>
       url_loader_header_client_receiver_{this};
   const content::ContentBrowserClient::URLLoaderFactoryType

+ 455 - 0
shell/browser/net/proxying_websocket.cc

@@ -0,0 +1,455 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "electron/shell/browser/net/proxying_websocket.h"
+
+#include <utility>
+
+#include "base/bind.h"
+#include "base/strings/string_util.h"
+#include "base/strings/stringprintf.h"
+#include "components/keyed_service/content/browser_context_keyed_service_shutdown_notifier_factory.h"
+#include "content/public/browser/browser_thread.h"
+#include "extensions/browser/extension_navigation_ui_data.h"
+#include "net/base/ip_endpoint.h"
+#include "net/http/http_util.h"
+
+namespace electron {
+
+ProxyingWebSocket::ProxyingWebSocket(
+    WebRequestAPI* web_request_api,
+    WebSocketFactory factory,
+    const network::ResourceRequest& request,
+    mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
+        handshake_client,
+    bool has_extra_headers,
+    int process_id,
+    int render_frame_id,
+    content::BrowserContext* browser_context,
+    uint64_t* request_id_generator)
+    : web_request_api_(web_request_api),
+      request_(request),
+      factory_(std::move(factory)),
+      forwarding_handshake_client_(std::move(handshake_client)),
+      request_headers_(request.headers),
+      has_extra_headers_(has_extra_headers),
+      info_(extensions::WebRequestInfoInitParams(
+          ++(*request_id_generator),
+          process_id,
+          render_frame_id,
+          nullptr,
+          MSG_ROUTING_NONE,
+          request,
+          /*is_download=*/false,
+          /*is_async=*/true,
+          /*is_service_worker_script=*/false)) {}
+
+ProxyingWebSocket::~ProxyingWebSocket() {
+  if (on_before_send_headers_callback_) {
+    std::move(on_before_send_headers_callback_)
+        .Run(net::ERR_ABORTED, base::nullopt);
+  }
+  if (on_headers_received_callback_) {
+    std::move(on_headers_received_callback_)
+        .Run(net::ERR_ABORTED, base::nullopt, GURL());
+  }
+}
+
+void ProxyingWebSocket::Start() {
+  // If the header client will be used, we start the request immediately, and
+  // OnBeforeSendHeaders and OnSendHeaders will be handled there. Otherwise,
+  // send these events before the request starts.
+  base::RepeatingCallback<void(int)> continuation;
+  if (has_extra_headers_) {
+    continuation = base::BindRepeating(
+        &ProxyingWebSocket::ContinueToStartRequest, weak_factory_.GetWeakPtr());
+  } else {
+    continuation =
+        base::BindRepeating(&ProxyingWebSocket::OnBeforeRequestComplete,
+                            weak_factory_.GetWeakPtr());
+  }
+
+  int result = web_request_api_->OnBeforeRequest(&info_, request_, continuation,
+                                                 &redirect_url_);
+
+  if (result == net::ERR_BLOCKED_BY_CLIENT) {
+    OnError(result);
+    return;
+  }
+
+  if (result == net::ERR_IO_PENDING) {
+    return;
+  }
+
+  DCHECK_EQ(net::OK, result);
+  continuation.Run(net::OK);
+}
+
+void ProxyingWebSocket::OnOpeningHandshakeStarted(
+    network::mojom::WebSocketHandshakeRequestPtr request) {
+  DCHECK(forwarding_handshake_client_);
+  forwarding_handshake_client_->OnOpeningHandshakeStarted(std::move(request));
+}
+
+void ProxyingWebSocket::OnResponseReceived(
+    network::mojom::WebSocketHandshakeResponsePtr response) {
+  DCHECK(forwarding_handshake_client_);
+
+  // response_.headers will be set in OnBeforeSendHeaders if
+  // |receiver_as_header_client_| is set.
+  if (!receiver_as_header_client_.is_bound()) {
+    response_.headers =
+        base::MakeRefCounted<net::HttpResponseHeaders>(base::StringPrintf(
+            "HTTP/%d.%d %d %s", response->http_version.major_value(),
+            response->http_version.minor_value(), response->status_code,
+            response->status_text.c_str()));
+    for (const auto& header : response->headers)
+      response_.headers->AddHeader(header->name + ": " + header->value);
+  }
+
+  response_.remote_endpoint = response->remote_endpoint;
+
+  // TODO(yhirano): OnResponseReceived is called with the original
+  // response headers. That means if OnHeadersReceived modified them the
+  // renderer won't see that modification. This is the opposite of http(s)
+  // requests.
+  forwarding_handshake_client_->OnResponseReceived(std::move(response));
+
+  if (!receiver_as_header_client_.is_bound() || response_.headers) {
+    ContinueToHeadersReceived();
+  } else {
+    waiting_for_header_client_headers_received_ = true;
+  }
+}
+
+void ProxyingWebSocket::ContinueToHeadersReceived() {
+  auto continuation =
+      base::BindRepeating(&ProxyingWebSocket::OnHeadersReceivedComplete,
+                          weak_factory_.GetWeakPtr());
+  info_.AddResponseInfoFromResourceResponse(response_);
+  int result = web_request_api_->OnHeadersReceived(
+      &info_, request_, continuation, response_.headers.get(),
+      &override_headers_, &redirect_url_);
+
+  if (result == net::ERR_BLOCKED_BY_CLIENT) {
+    OnError(result);
+    return;
+  }
+
+  PauseIncomingMethodCallProcessing();
+  if (result == net::ERR_IO_PENDING)
+    return;
+
+  DCHECK_EQ(net::OK, result);
+  OnHeadersReceivedComplete(net::OK);
+}
+
+void ProxyingWebSocket::OnConnectionEstablished(
+    mojo::PendingRemote<network::mojom::WebSocket> websocket,
+    mojo::PendingReceiver<network::mojom::WebSocketClient> client_receiver,
+    const std::string& selected_protocol,
+    const std::string& extensions,
+    mojo::ScopedDataPipeConsumerHandle readable) {
+  DCHECK(forwarding_handshake_client_);
+  DCHECK(!is_done_);
+  is_done_ = true;
+  web_request_api_->OnCompleted(&info_, request_, net::ERR_WS_UPGRADE);
+
+  forwarding_handshake_client_->OnConnectionEstablished(
+      std::move(websocket), std::move(client_receiver), selected_protocol,
+      extensions, std::move(readable));
+
+  // Deletes |this|.
+  delete this;
+}
+
+void ProxyingWebSocket::OnAuthRequired(
+    const net::AuthChallengeInfo& auth_info,
+    const scoped_refptr<net::HttpResponseHeaders>& headers,
+    const net::IPEndPoint& remote_endpoint,
+    OnAuthRequiredCallback callback) {
+  if (!callback) {
+    OnError(net::ERR_FAILED);
+    return;
+  }
+
+  response_.headers = headers;
+  response_.remote_endpoint = remote_endpoint;
+  auth_required_callback_ = std::move(callback);
+
+  auto continuation =
+      base::BindRepeating(&ProxyingWebSocket::OnHeadersReceivedCompleteForAuth,
+                          weak_factory_.GetWeakPtr(), auth_info);
+  info_.AddResponseInfoFromResourceResponse(response_);
+  int result = web_request_api_->OnHeadersReceived(
+      &info_, request_, continuation, response_.headers.get(),
+      &override_headers_, &redirect_url_);
+
+  if (result == net::ERR_BLOCKED_BY_CLIENT) {
+    OnError(result);
+    return;
+  }
+
+  PauseIncomingMethodCallProcessing();
+  if (result == net::ERR_IO_PENDING)
+    return;
+
+  DCHECK_EQ(net::OK, result);
+  OnHeadersReceivedCompleteForAuth(auth_info, net::OK);
+}
+
+void ProxyingWebSocket::OnBeforeSendHeaders(
+    const net::HttpRequestHeaders& headers,
+    OnBeforeSendHeadersCallback callback) {
+  DCHECK(receiver_as_header_client_.is_bound());
+
+  request_headers_ = headers;
+  on_before_send_headers_callback_ = std::move(callback);
+  OnBeforeRequestComplete(net::OK);
+}
+
+void ProxyingWebSocket::OnHeadersReceived(const std::string& headers,
+                                          OnHeadersReceivedCallback callback) {
+  DCHECK(receiver_as_header_client_.is_bound());
+
+  // Note: since there are different pipes used for WebSocketClient and
+  // TrustedHeaderClient, there are no guarantees whether this or
+  // OnResponseReceived are called first.
+  on_headers_received_callback_ = std::move(callback);
+  response_.headers = base::MakeRefCounted<net::HttpResponseHeaders>(headers);
+
+  if (!waiting_for_header_client_headers_received_)
+    return;
+
+  waiting_for_header_client_headers_received_ = false;
+  ContinueToHeadersReceived();
+}
+
+void ProxyingWebSocket::StartProxying(
+    WebRequestAPI* web_request_api,
+    WebSocketFactory factory,
+    const GURL& url,
+    const GURL& site_for_cookies,
+    const base::Optional<std::string>& user_agent,
+    mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
+        handshake_client,
+    bool has_extra_headers,
+    int process_id,
+    int render_frame_id,
+    const url::Origin& origin,
+    content::BrowserContext* browser_context,
+    uint64_t* request_id_generator) {
+  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
+  network::ResourceRequest request;
+  request.url = url;
+  request.site_for_cookies = site_for_cookies;
+  if (user_agent) {
+    request.headers.SetHeader(net::HttpRequestHeaders::kUserAgent, *user_agent);
+  }
+  request.request_initiator = origin;
+
+  auto* proxy = new ProxyingWebSocket(
+      web_request_api, std::move(factory), request, std::move(handshake_client),
+      has_extra_headers, process_id, render_frame_id, browser_context,
+      request_id_generator);
+  proxy->Start();
+}
+
+void ProxyingWebSocket::OnBeforeRequestComplete(int error_code) {
+  DCHECK(receiver_as_header_client_.is_bound() ||
+         !receiver_as_handshake_client_.is_bound());
+  DCHECK(info_.url.SchemeIsWSOrWSS());
+  if (error_code != net::OK) {
+    OnError(error_code);
+    return;
+  }
+
+  auto continuation =
+      base::BindRepeating(&ProxyingWebSocket::OnBeforeSendHeadersComplete,
+                          weak_factory_.GetWeakPtr());
+
+  info_.AddResponseInfoFromResourceResponse(response_);
+  int result = web_request_api_->OnBeforeSendHeaders(
+      &info_, request_, continuation, &request_headers_);
+
+  if (result == net::ERR_BLOCKED_BY_CLIENT) {
+    OnError(result);
+    return;
+  }
+
+  if (result == net::ERR_IO_PENDING)
+    return;
+
+  DCHECK_EQ(net::OK, result);
+  OnBeforeSendHeadersComplete(std::set<std::string>(), std::set<std::string>(),
+                              net::OK);
+}
+
+void ProxyingWebSocket::OnBeforeSendHeadersComplete(
+    const std::set<std::string>& removed_headers,
+    const std::set<std::string>& set_headers,
+    int error_code) {
+  DCHECK(receiver_as_header_client_.is_bound() ||
+         !receiver_as_handshake_client_.is_bound());
+  if (error_code != net::OK) {
+    OnError(error_code);
+    return;
+  }
+
+  if (receiver_as_header_client_.is_bound()) {
+    CHECK(on_before_send_headers_callback_);
+    std::move(on_before_send_headers_callback_)
+        .Run(error_code, request_headers_);
+  }
+
+  info_.AddResponseInfoFromResourceResponse(response_);
+  web_request_api_->OnSendHeaders(&info_, request_, request_headers_);
+
+  if (!receiver_as_header_client_.is_bound())
+    ContinueToStartRequest(net::OK);
+}
+
+void ProxyingWebSocket::ContinueToStartRequest(int error_code) {
+  if (error_code != net::OK) {
+    OnError(error_code);
+    return;
+  }
+
+  base::flat_set<std::string> used_header_names;
+  std::vector<network::mojom::HttpHeaderPtr> additional_headers;
+  for (net::HttpRequestHeaders::Iterator it(request_headers_); it.GetNext();) {
+    additional_headers.push_back(
+        network::mojom::HttpHeader::New(it.name(), it.value()));
+    used_header_names.insert(base::ToLowerASCII(it.name()));
+  }
+  for (const auto& header : additional_headers_) {
+    if (!used_header_names.contains(base::ToLowerASCII(header->name))) {
+      additional_headers.push_back(
+          network::mojom::HttpHeader::New(header->name, header->value));
+    }
+  }
+
+  mojo::PendingRemote<network::mojom::TrustedHeaderClient>
+      trusted_header_client = mojo::NullRemote();
+  if (has_extra_headers_) {
+    trusted_header_client =
+        receiver_as_header_client_.BindNewPipeAndPassRemote();
+  }
+
+  std::move(factory_).Run(
+      info_.url, std::move(additional_headers),
+      receiver_as_handshake_client_.BindNewPipeAndPassRemote(),
+      receiver_as_auth_handler_.BindNewPipeAndPassRemote(),
+      std::move(trusted_header_client));
+
+  // Here we detect mojo connection errors on |receiver_as_handshake_client_|.
+  // See also CreateWebSocket in
+  // //network/services/public/mojom/network_context.mojom.
+  receiver_as_handshake_client_.set_disconnect_with_reason_handler(
+      base::BindOnce(&ProxyingWebSocket::OnMojoConnectionErrorWithCustomReason,
+                     base::Unretained(this)));
+  forwarding_handshake_client_.set_disconnect_handler(base::BindOnce(
+      &ProxyingWebSocket::OnMojoConnectionError, base::Unretained(this)));
+}
+
+void ProxyingWebSocket::OnHeadersReceivedComplete(int error_code) {
+  if (error_code != net::OK) {
+    OnError(error_code);
+    return;
+  }
+
+  if (on_headers_received_callback_) {
+    base::Optional<std::string> headers;
+    if (override_headers_)
+      headers = override_headers_->raw_headers();
+    std::move(on_headers_received_callback_).Run(net::OK, headers, GURL());
+  }
+
+  if (override_headers_) {
+    response_.headers = override_headers_;
+    override_headers_ = nullptr;
+  }
+
+  ResumeIncomingMethodCallProcessing();
+  info_.AddResponseInfoFromResourceResponse(response_);
+  web_request_api_->OnResponseStarted(&info_, request_);
+}
+
+void ProxyingWebSocket::OnAuthRequiredComplete(
+    net::NetworkDelegate::AuthRequiredResponse rv) {
+  CHECK(auth_required_callback_);
+  ResumeIncomingMethodCallProcessing();
+  switch (rv) {
+    case net::NetworkDelegate::AUTH_REQUIRED_RESPONSE_NO_ACTION:
+    case net::NetworkDelegate::AUTH_REQUIRED_RESPONSE_CANCEL_AUTH:
+      std::move(auth_required_callback_).Run(base::nullopt);
+      break;
+
+    case net::NetworkDelegate::AUTH_REQUIRED_RESPONSE_SET_AUTH:
+      std::move(auth_required_callback_).Run(auth_credentials_);
+      break;
+    case net::NetworkDelegate::AUTH_REQUIRED_RESPONSE_IO_PENDING:
+      NOTREACHED();
+      break;
+  }
+}
+
+void ProxyingWebSocket::OnHeadersReceivedCompleteForAuth(
+    const net::AuthChallengeInfo& auth_info,
+    int rv) {
+  if (rv != net::OK) {
+    OnError(rv);
+    return;
+  }
+  ResumeIncomingMethodCallProcessing();
+  info_.AddResponseInfoFromResourceResponse(response_);
+
+  auto continuation = base::BindRepeating(
+      &ProxyingWebSocket::OnAuthRequiredComplete, weak_factory_.GetWeakPtr());
+  auto auth_rv = net::NetworkDelegate::AUTH_REQUIRED_RESPONSE_IO_PENDING;
+  PauseIncomingMethodCallProcessing();
+
+  OnAuthRequiredComplete(auth_rv);
+}
+
+void ProxyingWebSocket::PauseIncomingMethodCallProcessing() {
+  receiver_as_handshake_client_.Pause();
+  receiver_as_auth_handler_.Pause();
+  if (receiver_as_header_client_.is_bound())
+    receiver_as_header_client_.Pause();
+}
+
+void ProxyingWebSocket::ResumeIncomingMethodCallProcessing() {
+  receiver_as_handshake_client_.Resume();
+  receiver_as_auth_handler_.Resume();
+  if (receiver_as_header_client_.is_bound())
+    receiver_as_header_client_.Resume();
+}
+
+void ProxyingWebSocket::OnError(int error_code) {
+  if (!is_done_) {
+    is_done_ = true;
+    web_request_api_->OnErrorOccurred(&info_, request_, error_code);
+  }
+
+  // Deletes |this|.
+  delete this;
+}
+
+void ProxyingWebSocket::OnMojoConnectionErrorWithCustomReason(
+    uint32_t custom_reason,
+    const std::string& description) {
+  // Here we want to nofiy the custom reason to the client, which is why
+  // we reset |forwarding_handshake_client_| manually.
+  forwarding_handshake_client_.ResetWithReason(custom_reason, description);
+  OnError(net::ERR_FAILED);
+  // Deletes |this|.
+}
+
+void ProxyingWebSocket::OnMojoConnectionError() {
+  OnError(net::ERR_FAILED);
+  // Deletes |this|.
+}
+
+}  // namespace electron

+ 164 - 0
shell/browser/net/proxying_websocket.h

@@ -0,0 +1,164 @@
+// Copyright 2020 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHELL_BROWSER_NET_PROXYING_WEBSOCKET_H_
+#define SHELL_BROWSER_NET_PROXYING_WEBSOCKET_H_
+
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "base/optional.h"
+#include "components/keyed_service/core/keyed_service_shutdown_notifier.h"
+#include "extensions/browser/api/web_request/web_request_api.h"
+#include "extensions/browser/api/web_request/web_request_info.h"
+#include "mojo/public/cpp/bindings/pending_receiver.h"
+#include "mojo/public/cpp/bindings/receiver.h"
+#include "mojo/public/cpp/bindings/remote.h"
+#include "net/base/network_delegate.h"
+#include "services/network/public/cpp/resource_request.h"
+#include "services/network/public/cpp/resource_response.h"
+#include "services/network/public/mojom/network_context.mojom.h"
+#include "services/network/public/mojom/websocket.mojom.h"
+#include "shell/browser/net/web_request_api_interface.h"
+#include "url/gurl.h"
+#include "url/origin.h"
+
+namespace electron {
+
+// A ProxyingWebSocket proxies a WebSocket connection and dispatches
+// WebRequest API events.
+//
+// The code is referenced from the
+// extensions::WebRequestProxyingWebSocket class.
+class ProxyingWebSocket : public network::mojom::WebSocketHandshakeClient,
+                          public network::mojom::AuthenticationHandler,
+                          public network::mojom::TrustedHeaderClient {
+ public:
+  using WebSocketFactory = content::ContentBrowserClient::WebSocketFactory;
+
+  ProxyingWebSocket(
+      WebRequestAPI* web_request_api,
+      WebSocketFactory factory,
+      const network::ResourceRequest& request,
+      mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
+          handshake_client,
+      bool has_extra_headers,
+      int process_id,
+      int render_frame_id,
+      content::BrowserContext* browser_context,
+      uint64_t* request_id_generator);
+  ~ProxyingWebSocket() override;
+
+  void Start();
+
+  // network::mojom::WebSocketHandshakeClient methods:
+  void OnOpeningHandshakeStarted(
+      network::mojom::WebSocketHandshakeRequestPtr request) override;
+  void OnResponseReceived(
+      network::mojom::WebSocketHandshakeResponsePtr response) override;
+  void OnConnectionEstablished(
+      mojo::PendingRemote<network::mojom::WebSocket> websocket,
+      mojo::PendingReceiver<network::mojom::WebSocketClient> client_receiver,
+      const std::string& selected_protocol,
+      const std::string& extensions,
+      mojo::ScopedDataPipeConsumerHandle readable) override;
+
+  // network::mojom::AuthenticationHandler method:
+  void OnAuthRequired(const net::AuthChallengeInfo& auth_info,
+                      const scoped_refptr<net::HttpResponseHeaders>& headers,
+                      const net::IPEndPoint& remote_endpoint,
+                      OnAuthRequiredCallback callback) override;
+
+  // network::mojom::TrustedHeaderClient methods:
+  void OnBeforeSendHeaders(const net::HttpRequestHeaders& headers,
+                           OnBeforeSendHeadersCallback callback) override;
+  void OnHeadersReceived(const std::string& headers,
+                         OnHeadersReceivedCallback callback) override;
+
+  static void StartProxying(
+      WebRequestAPI* web_request_api,
+      WebSocketFactory factory,
+      const GURL& url,
+      const GURL& site_for_cookies,
+      const base::Optional<std::string>& user_agent,
+      mojo::PendingRemote<network::mojom::WebSocketHandshakeClient>
+          handshake_client,
+      bool has_extra_headers,
+      int process_id,
+      int render_frame_id,
+      const url::Origin& origin,
+      content::BrowserContext* browser_context,
+      uint64_t* request_id_generator);
+
+  WebRequestAPI* web_request_api() { return web_request_api_; }
+
+ private:
+  void OnBeforeRequestComplete(int error_code);
+  void OnBeforeSendHeadersComplete(const std::set<std::string>& removed_headers,
+                                   const std::set<std::string>& set_headers,
+                                   int error_code);
+  void ContinueToStartRequest(int error_code);
+  void OnHeadersReceivedComplete(int error_code);
+  void ContinueToHeadersReceived();
+  void OnAuthRequiredComplete(net::NetworkDelegate::AuthRequiredResponse rv);
+  void OnHeadersReceivedCompleteForAuth(const net::AuthChallengeInfo& auth_info,
+                                        int rv);
+
+  void PauseIncomingMethodCallProcessing();
+  void ResumeIncomingMethodCallProcessing();
+  void OnError(int result);
+  // This is used for detecting errors on mojo connection with the network
+  // service.
+  void OnMojoConnectionErrorWithCustomReason(uint32_t custom_reason,
+                                             const std::string& description);
+  // This is used for detecting errors on mojo connection with original client
+  // (i.e., renderer).
+  void OnMojoConnectionError();
+
+  // Passed from api::WebRequest.
+  WebRequestAPI* web_request_api_;
+
+  // Saved to feed the api::WebRequest.
+  network::ResourceRequest request_;
+
+  WebSocketFactory factory_;
+  mojo::Remote<network::mojom::WebSocketHandshakeClient>
+      forwarding_handshake_client_;
+  mojo::Receiver<network::mojom::WebSocketHandshakeClient>
+      receiver_as_handshake_client_{this};
+  mojo::Receiver<network::mojom::AuthenticationHandler>
+      receiver_as_auth_handler_{this};
+  mojo::Receiver<network::mojom::TrustedHeaderClient>
+      receiver_as_header_client_{this};
+
+  net::HttpRequestHeaders request_headers_;
+  network::ResourceResponseHead response_;
+  net::AuthCredentials auth_credentials_;
+  OnAuthRequiredCallback auth_required_callback_;
+  scoped_refptr<net::HttpResponseHeaders> override_headers_;
+  std::vector<network::mojom::HttpHeaderPtr> additional_headers_;
+
+  OnBeforeSendHeadersCallback on_before_send_headers_callback_;
+  OnHeadersReceivedCallback on_headers_received_callback_;
+
+  GURL redirect_url_;
+  bool is_done_ = false;
+  bool waiting_for_header_client_headers_received_ = false;
+  bool has_extra_headers_;
+
+  extensions::WebRequestInfo info_;
+
+  // Notifies the proxy that the browser context has been shutdown.
+  std::unique_ptr<KeyedServiceShutdownNotifier::Subscription>
+      shutdown_notifier_;
+
+  base::WeakPtrFactory<ProxyingWebSocket> weak_factory_{this};
+  DISALLOW_COPY_AND_ASSIGN(ProxyingWebSocket);
+};
+
+}  // namespace electron
+
+#endif  // SHELL_BROWSER_NET_PROXYING_WEBSOCKET_H_

+ 61 - 0
shell/browser/net/web_request_api_interface.h

@@ -0,0 +1,61 @@
+// Copyright (c) 2020 GitHub, Inc.
+// Use of this source code is governed by the MIT license that can be
+// found in the LICENSE file.
+
+#ifndef SHELL_BROWSER_NET_WEB_REQUEST_API_INTERFACE_H_
+#define SHELL_BROWSER_NET_WEB_REQUEST_API_INTERFACE_H_
+
+#include <set>
+#include <string>
+
+#include "extensions/browser/api/web_request/web_request_info.h"
+#include "services/network/public/cpp/resource_request.h"
+
+namespace electron {
+
+// Defines the interface for WebRequest API, implemented by api::WebRequestNS.
+class WebRequestAPI {
+ public:
+  virtual ~WebRequestAPI() {}
+
+  using BeforeSendHeadersCallback =
+      base::OnceCallback<void(const std::set<std::string>& removed_headers,
+                              const std::set<std::string>& set_headers,
+                              int error_code)>;
+
+  virtual bool HasListener() const = 0;
+  virtual int OnBeforeRequest(extensions::WebRequestInfo* info,
+                              const network::ResourceRequest& request,
+                              net::CompletionOnceCallback callback,
+                              GURL* new_url) = 0;
+  virtual int OnBeforeSendHeaders(extensions::WebRequestInfo* info,
+                                  const network::ResourceRequest& request,
+                                  BeforeSendHeadersCallback callback,
+                                  net::HttpRequestHeaders* headers) = 0;
+  virtual int OnHeadersReceived(
+      extensions::WebRequestInfo* info,
+      const network::ResourceRequest& request,
+      net::CompletionOnceCallback callback,
+      const net::HttpResponseHeaders* original_response_headers,
+      scoped_refptr<net::HttpResponseHeaders>* override_response_headers,
+      GURL* allowed_unsafe_redirect_url) = 0;
+  virtual void OnSendHeaders(extensions::WebRequestInfo* info,
+                             const network::ResourceRequest& request,
+                             const net::HttpRequestHeaders& headers) = 0;
+  virtual void OnBeforeRedirect(extensions::WebRequestInfo* info,
+                                const network::ResourceRequest& request,
+                                const GURL& new_location) = 0;
+  virtual void OnResponseStarted(extensions::WebRequestInfo* info,
+                                 const network::ResourceRequest& request) = 0;
+  virtual void OnErrorOccurred(extensions::WebRequestInfo* info,
+                               const network::ResourceRequest& request,
+                               int net_error) = 0;
+  virtual void OnCompleted(extensions::WebRequestInfo* info,
+                           const network::ResourceRequest& request,
+                           int net_error) = 0;
+  virtual void OnRequestWillBeDestroyed(extensions::WebRequestInfo* info) = 0;
+};
+
+}  // namespace electron
+
+#endif  // SHELL_BROWSER_NET_WEB_REQUEST_API_INTERFACE_H_

+ 99 - 1
spec-main/api-web-request-spec.ts

@@ -2,8 +2,10 @@ import { expect } from 'chai'
 import * as http from 'http'
 import * as qs from 'querystring'
 import * as path from 'path'
-import { session, WebContents, webContents } from 'electron'
+import * as WebSocket from 'ws'
+import { ipcMain, session, WebContents, webContents } from 'electron'
 import { AddressInfo } from 'net';
+import { emittedOnce } from './events-helpers'
 
 const fixturesPath = path.resolve(__dirname, '..', 'spec', 'fixtures')
 
@@ -339,4 +341,100 @@ describe('webRequest module', () => {
       await expect(ajax(defaultURL)).to.eventually.be.rejectedWith('404')
     })
   })
+
+  describe('WebSocket connections', () => {
+    it('can be proxyed', async () => {
+      // Setup server.
+      const reqHeaders : { [key: string] : any } = {}
+      const server = http.createServer((req, res) => {
+        reqHeaders[req.url!] = req.headers
+        res.setHeader('foo1', 'bar1')
+        res.end('ok')
+      })
+      const wss = new WebSocket.Server({ noServer: true })
+      wss.on('connection', function connection (ws) {
+        ws.on('message', function incoming (message) {
+          if (message === 'foo') {
+            ws.send('bar')
+          }
+        })
+      })
+      server.on('upgrade', function upgrade (request, socket, head) {
+        const pathname = require('url').parse(request.url).pathname
+        if (pathname === '/websocket') {
+          reqHeaders[request.url] = request.headers
+          wss.handleUpgrade(request, socket, head, function done (ws) {
+            wss.emit('connection', ws, request)
+          })
+        }
+      })
+
+      // Start server.
+      await new Promise(resolve => server.listen(0, '127.0.0.1', resolve))
+      const port = String((server.address() as AddressInfo).port)
+
+      // Use a separate session for testing.
+      const ses = session.fromPartition('WebRequestWebSocket')
+
+      // Setup listeners.
+      const receivedHeaders : { [key: string] : any } = {}
+      ses.webRequest.onBeforeSendHeaders((details, callback) => {
+        details.requestHeaders.foo = 'bar'
+        callback({ requestHeaders: details.requestHeaders })
+      })
+      ses.webRequest.onHeadersReceived((details, callback) => {
+        const pathname = require('url').parse(details.url).pathname
+        receivedHeaders[pathname] = details.responseHeaders
+        callback({ cancel: false })
+      })
+      ses.webRequest.onResponseStarted((details) => {
+        if (details.url.startsWith('ws://')) {
+          expect(details.responseHeaders!['Connection'][0]).be.equal('Upgrade')
+        } else if (details.url.startsWith('http')) {
+          expect(details.responseHeaders!['foo1'][0]).be.equal('bar1')
+        }
+      })
+      ses.webRequest.onSendHeaders((details) => {
+        if (details.url.startsWith('ws://')) {
+          expect(details.requestHeaders['foo']).be.equal('bar')
+          expect(details.requestHeaders['Upgrade']).be.equal('websocket')
+        } else if (details.url.startsWith('http')) {
+          expect(details.requestHeaders['foo']).be.equal('bar')
+        }
+      })
+      ses.webRequest.onCompleted((details) => {
+        if (details.url.startsWith('ws://')) {
+          expect(details['error']).be.equal('net::ERR_WS_UPGRADE')
+        } else if (details.url.startsWith('http')) {
+          expect(details['error']).be.equal('net::OK')
+        }
+      })
+
+      const contents = (webContents as any).create({
+        session: ses,
+        nodeIntegration: true,
+        webSecurity: false
+      })
+
+      // Cleanup.
+      after(() => {
+        contents.destroy()
+        server.close()
+        ses.webRequest.onBeforeRequest(null)
+        ses.webRequest.onBeforeSendHeaders(null)
+        ses.webRequest.onHeadersReceived(null)
+        ses.webRequest.onResponseStarted(null)
+        ses.webRequest.onSendHeaders(null)
+        ses.webRequest.onCompleted(null)
+      })
+
+      contents.loadFile(path.join(__dirname, 'fixtures', 'api', 'webrequest.html'), { query: { port } })
+      await emittedOnce(ipcMain, 'websocket-success')
+
+      expect(receivedHeaders['/websocket']['Upgrade'][0]).to.equal('websocket')
+      expect(receivedHeaders['/']['foo1'][0]).to.equal('bar1')
+      expect(reqHeaders['/websocket']['foo']).to.equal('bar')
+      expect(reqHeaders['/']['foo']).to.equal('bar')
+    })
+  })
 })

+ 27 - 0
spec-main/fixtures/api/webrequest.html

@@ -0,0 +1,27 @@
+<script>
+  var url = new URL(location.href)
+  const port = new URLSearchParams(url.search).get("port")
+  const ipcRenderer = require('electron').ipcRenderer
+  let count = 0
+  function checkFinish() {
+    count++
+    if (count === 2) {
+      ipcRenderer.send('websocket-success')
+    }
+  }
+
+  var conn = new WebSocket(`ws://127.0.0.1:${port}/websocket`)
+  conn.onopen = data => conn.send('foo')
+  conn.onmessage = wsMsg
+  function wsMsg(msg) {
+    if (msg.data === 'bar') {
+      checkFinish()
+    } else {
+      ipcRenderer.send('fail')
+    }
+  }
+
+  fetch(`http://127.0.0.1:${port}/`).then(() => {
+    checkFinish()
+  })
+</script>

+ 5 - 1
spec-main/package.json

@@ -2,5 +2,9 @@
   "name": "electron-test-main",
   "productName": "Electron Test Main",
   "main": "index.js",
-  "version": "0.1.0"
+  "version": "0.1.0",
+  "devDependencies": {
+    "@types/ws": "^7.2.0",
+    "ws": "^7.2.1"
+  }
 }

+ 20 - 0
spec-main/yarn.lock

@@ -0,0 +1,20 @@
+# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
+# yarn lockfile v1
+
+
+"@types/node@*":
+  version "13.7.0"
+  resolved "https://registry.yarnpkg.com/@types/node/-/node-13.7.0.tgz#b417deda18cf8400f278733499ad5547ed1abec4"
+  integrity sha512-GnZbirvmqZUzMgkFn70c74OQpTTUcCzlhQliTzYjQMqg+hVKcDnxdL19Ne3UdYzdMA/+W3eb646FWn/ZaT1NfQ==
+
+"@types/ws@^7.2.0":
+  version "7.2.1"
+  resolved "https://registry.yarnpkg.com/@types/ws/-/ws-7.2.1.tgz#b800f2b8aee694e2b581113643e20d79dd3b8556"
+  integrity sha512-UEmRNbXFGvfs/sLncf01GuVv6U1mZP3Df0iXWx4kUlikJxbFyFADp95mDn1XDTE2mXpzzoHcKlfFcbytLq4vaA==
+  dependencies:
+    "@types/node" "*"
+
+ws@^7.2.1:
+  version "7.2.1"
+  resolved "https://registry.yarnpkg.com/ws/-/ws-7.2.1.tgz#03ed52423cd744084b2cf42ed197c8b65a936b8e"
+  integrity sha512-sucePNSafamSKoOqoNfBd8V0StlkzJKL2ZAhGQinCfNQ+oacw+Pk7lcdAElecBF2VkLNZRiIb5Oi1Q5lVUVt2A==