Browse Source

fix: use `BlockedRequest` struct to handle `webRequest` data (#42750)

* refactor: use BlockedRequest model to handle webRequest

Co-authored-by: Shelley Vohr <[email protected]>

* refactor: finish de-templating

Co-authored-by: Shelley Vohr <[email protected]>

* chore: address some feedback from review

Co-authored-by: Shelley Vohr <[email protected]>

---------

Co-authored-by: trop[bot] <37223003+trop[bot]@users.noreply.github.com>
Co-authored-by: Shelley Vohr <[email protected]>
trop[bot] 9 months ago
parent
commit
57e859d0af

+ 274 - 101
shell/browser/api/electron_api_web_request.cc

@@ -185,37 +185,41 @@ void FillDetails(gin_helper::Dictionary* details, Arg arg, Args... args) {
   FillDetails(details, args...);
 }
 
-// Fill the native types with the result from the response object.
-void ReadFromResponse(v8::Isolate* isolate,
-                      gin::Dictionary* response,
-                      GURL* new_location) {
-  response->Get("redirectURL", new_location);
-}
-
-void ReadFromResponse(v8::Isolate* isolate,
-                      gin::Dictionary* response,
-                      net::HttpRequestHeaders* headers) {
-  v8::Local<v8::Value> value;
-  if (response->Get("requestHeaders", &value) && value->IsObject()) {
-    headers->Clear();
-    gin::Converter<net::HttpRequestHeaders>::FromV8(isolate, value, headers);
-  }
-}
+// Modified from extensions/browser/api/web_request/web_request_api_helpers.cc.
+std::pair<std::set<std::string>, std::set<std::string>>
+CalculateOnBeforeSendHeadersDelta(const net::HttpRequestHeaders* old_headers,
+                                  const net::HttpRequestHeaders* new_headers) {
+  // Newly introduced or overridden request headers.
+  std::set<std::string> modified_request_headers;
+  // Keys of request headers to be deleted.
+  std::set<std::string> deleted_request_headers;
+
+  // The event listener might not have passed any new headers if it
+  // just wanted to cancel the request.
+  if (new_headers) {
+    // Find deleted headers.
+    {
+      net::HttpRequestHeaders::Iterator i(*old_headers);
+      while (i.GetNext()) {
+        if (!new_headers->HasHeader(i.name())) {
+          deleted_request_headers.insert(i.name());
+        }
+      }
+    }
 
-void ReadFromResponse(v8::Isolate* isolate,
-                      gin::Dictionary* response,
-                      const std::pair<scoped_refptr<net::HttpResponseHeaders>*,
-                                      const std::string>& headers) {
-  std::string status_line;
-  if (!response->Get("statusLine", &status_line))
-    status_line = headers.second;
-  v8::Local<v8::Value> value;
-  if (response->Get("responseHeaders", &value) && value->IsObject()) {
-    *headers.first = new net::HttpResponseHeaders("");
-    (*headers.first)->ReplaceStatusLine(status_line);
-    gin::Converter<net::HttpResponseHeaders*>::FromV8(isolate, value,
-                                                      (*headers.first).get());
+    // Find modified headers.
+    {
+      net::HttpRequestHeaders::Iterator i(*new_headers);
+      while (i.GetNext()) {
+        std::string value;
+        if (!old_headers->GetHeader(i.name(), &value) || i.value() != value) {
+          modified_request_headers.insert(i.name());
+        }
+      }
+    }
   }
+
+  return std::make_pair(modified_request_headers, deleted_request_headers);
 }
 
 }  // namespace
@@ -260,6 +264,24 @@ bool WebRequest::RequestFilter::MatchesRequest(
   return MatchesURL(info->url) && MatchesType(info->web_request_type);
 }
 
+struct WebRequest::BlockedRequest {
+  BlockedRequest() = default;
+  raw_ptr<const extensions::WebRequestInfo> request = nullptr;
+  net::CompletionOnceCallback callback;
+  // Only used for onBeforeSendHeaders.
+  BeforeSendHeadersCallback before_send_headers_callback;
+  // Only used for onBeforeSendHeaders.
+  raw_ptr<net::HttpRequestHeaders> request_headers = nullptr;
+  // Only used for onHeadersReceived.
+  scoped_refptr<const net::HttpResponseHeaders> original_response_headers;
+  // Only used for onHeadersReceived.
+  raw_ptr<scoped_refptr<net::HttpResponseHeaders>> override_response_headers =
+      nullptr;
+  std::string status_line;
+  // Only used for onBeforeRequest.
+  raw_ptr<GURL> new_url = nullptr;
+};
+
 WebRequest::SimpleListenerInfo::SimpleListenerInfo(RequestFilter filter_,
                                                    SimpleListener listener_)
     : filter(std::move(filter_)), listener(listener_) {}
@@ -320,19 +342,152 @@ int WebRequest::OnBeforeRequest(extensions::WebRequestInfo* info,
                                 const network::ResourceRequest& request,
                                 net::CompletionOnceCallback callback,
                                 GURL* new_url) {
-  return HandleResponseEvent(ResponseEvent::kOnBeforeRequest, info,
-                             std::move(callback), new_url, request);
+  return HandleOnBeforeRequestResponseEvent(info, request, std::move(callback),
+                                            new_url);
+}
+
+int WebRequest::HandleOnBeforeRequestResponseEvent(
+    extensions::WebRequestInfo* request_info,
+    const network::ResourceRequest& request,
+    net::CompletionOnceCallback callback,
+    GURL* new_url) {
+  const auto iter = response_listeners_.find(ResponseEvent::kOnBeforeRequest);
+  if (iter == std::end(response_listeners_))
+    return net::OK;
+
+  const auto& info = iter->second;
+  if (!info.filter.MatchesRequest(request_info))
+    return net::OK;
+
+  BlockedRequest blocked_request;
+  blocked_request.callback = std::move(callback);
+  blocked_request.new_url = new_url;
+  blocked_requests_[request_info->id] = std::move(blocked_request);
+
+  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
+  v8::HandleScope handle_scope(isolate);
+  gin_helper::Dictionary details(isolate, v8::Object::New(isolate));
+  FillDetails(&details, request_info, request, *new_url);
+
+  ResponseCallback response =
+      base::BindOnce(&WebRequest::OnBeforeRequestListenerResult,
+                     base::Unretained(this), request_info->id);
+  info.listener.Run(gin::ConvertToV8(isolate, details), std::move(response));
+  return net::ERR_IO_PENDING;
+}
+
+void WebRequest::OnBeforeRequestListenerResult(uint64_t id,
+                                               v8::Local<v8::Value> response) {
+  const auto iter = blocked_requests_.find(id);
+  if (iter == std::end(blocked_requests_))
+    return;
+
+  auto& request = iter->second;
+
+  int result = net::OK;
+  if (response->IsObject()) {
+    v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
+    gin::Dictionary dict(isolate, response.As<v8::Object>());
+
+    bool cancel = false;
+    dict.Get("cancel", &cancel);
+    if (cancel) {
+      result = net::ERR_BLOCKED_BY_CLIENT;
+    } else {
+      dict.Get("redirectURL", request.new_url.get());
+    }
+  }
+
+  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
+      FROM_HERE, base::BindOnce(std::move(request.callback), result));
+  blocked_requests_.erase(iter);
 }
 
 int WebRequest::OnBeforeSendHeaders(extensions::WebRequestInfo* info,
                                     const network::ResourceRequest& request,
                                     BeforeSendHeadersCallback callback,
                                     net::HttpRequestHeaders* headers) {
-  return HandleResponseEvent(
-      ResponseEvent::kOnBeforeSendHeaders, info,
-      base::BindOnce(std::move(callback), std::set<std::string>(),
-                     std::set<std::string>()),
-      headers, request, *headers);
+  return HandleOnBeforeSendHeadersResponseEvent(info, request,
+                                                std::move(callback), headers);
+}
+
+int WebRequest::HandleOnBeforeSendHeadersResponseEvent(
+    extensions::WebRequestInfo* request_info,
+    const network::ResourceRequest& request,
+    BeforeSendHeadersCallback callback,
+    net::HttpRequestHeaders* headers) {
+  const auto iter =
+      response_listeners_.find(ResponseEvent::kOnBeforeSendHeaders);
+  if (iter == std::end(response_listeners_))
+    return net::OK;
+
+  const auto& info = iter->second;
+  if (!info.filter.MatchesRequest(request_info))
+    return net::OK;
+
+  BlockedRequest blocked_request;
+  blocked_request.before_send_headers_callback = std::move(callback);
+  blocked_request.request_headers = headers;
+  blocked_requests_[request_info->id] = std::move(blocked_request);
+
+  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
+  v8::HandleScope handle_scope(isolate);
+  gin_helper::Dictionary details(isolate, v8::Object::New(isolate));
+  FillDetails(&details, request_info, request, *headers);
+
+  ResponseCallback response =
+      base::BindOnce(&WebRequest::OnBeforeSendHeadersListenerResult,
+                     base::Unretained(this), request_info->id);
+  info.listener.Run(gin::ConvertToV8(isolate, details), std::move(response));
+  return net::ERR_IO_PENDING;
+}
+
+void WebRequest::OnBeforeSendHeadersListenerResult(
+    uint64_t id,
+    v8::Local<v8::Value> response) {
+  const auto iter = blocked_requests_.find(id);
+  if (iter == std::end(blocked_requests_))
+    return;
+
+  auto& request = iter->second;
+
+  net::HttpRequestHeaders* old_headers = request.request_headers;
+  net::HttpRequestHeaders new_headers;
+
+  int result = net::OK;
+  bool user_modified_headers = false;
+  if (response->IsObject()) {
+    v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
+    gin::Dictionary dict(isolate, response.As<v8::Object>());
+
+    bool cancel = false;
+    dict.Get("cancel", &cancel);
+    if (cancel) {
+      result = net::ERR_BLOCKED_BY_CLIENT;
+    } else {
+      v8::Local<v8::Value> value;
+      if (dict.Get("requestHeaders", &value) && value->IsObject()) {
+        user_modified_headers = true;
+        gin::Converter<net::HttpRequestHeaders>::FromV8(isolate, value,
+                                                        &new_headers);
+      }
+    }
+  }
+
+  // If the user passes |cancel|, |new_headers| should be nullptr.
+  const auto updated_headers = CalculateOnBeforeSendHeadersDelta(
+      old_headers,
+      result == net::ERR_BLOCKED_BY_CLIENT ? nullptr : &new_headers);
+
+  // Leave |request.request_headers| unchanged if the user didn't modify it.
+  if (user_modified_headers)
+    request.request_headers->Swap(&new_headers);
+
+  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
+      FROM_HERE,
+      base::BindOnce(std::move(request.before_send_headers_callback),
+                     updated_headers.first, updated_headers.second, result));
+  blocked_requests_.erase(iter);
 }
 
 int WebRequest::OnHeadersReceived(
@@ -342,12 +497,86 @@ int WebRequest::OnHeadersReceived(
     const net::HttpResponseHeaders* original_response_headers,
     scoped_refptr<net::HttpResponseHeaders>* override_response_headers,
     GURL* allowed_unsafe_redirect_url) {
-  const std::string& status_line =
-      original_response_headers ? original_response_headers->GetStatusLine()
-                                : std::string();
-  return HandleResponseEvent(
-      ResponseEvent::kOnHeadersReceived, info, std::move(callback),
-      std::make_pair(override_response_headers, status_line), request);
+  return HandleOnHeadersReceivedResponseEvent(
+      info, request, std::move(callback), original_response_headers,
+      override_response_headers);
+}
+
+int WebRequest::HandleOnHeadersReceivedResponseEvent(
+    extensions::WebRequestInfo* request_info,
+    const network::ResourceRequest& request,
+    net::CompletionOnceCallback callback,
+    const net::HttpResponseHeaders* original_response_headers,
+    scoped_refptr<net::HttpResponseHeaders>* override_response_headers) {
+  const auto iter = response_listeners_.find(ResponseEvent::kOnHeadersReceived);
+  if (iter == std::end(response_listeners_))
+    return net::OK;
+
+  const auto& info = iter->second;
+  if (!info.filter.MatchesRequest(request_info))
+    return net::OK;
+
+  BlockedRequest blocked_request;
+  blocked_request.callback = std::move(callback);
+  blocked_request.override_response_headers = override_response_headers;
+  blocked_request.status_line = original_response_headers
+                                    ? original_response_headers->GetStatusLine()
+                                    : std::string();
+  blocked_requests_[request_info->id] = std::move(blocked_request);
+
+  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
+  v8::HandleScope handle_scope(isolate);
+  gin_helper::Dictionary details(isolate, v8::Object::New(isolate));
+  FillDetails(&details, request_info, request);
+
+  ResponseCallback response =
+      base::BindOnce(&WebRequest::OnHeadersReceivedListenerResult,
+                     base::Unretained(this), request_info->id);
+  info.listener.Run(gin::ConvertToV8(isolate, details), std::move(response));
+  return net::ERR_IO_PENDING;
+}
+
+void WebRequest::OnHeadersReceivedListenerResult(
+    uint64_t id,
+    v8::Local<v8::Value> response) {
+  const auto iter = blocked_requests_.find(id);
+  if (iter == std::end(blocked_requests_))
+    return;
+
+  auto& request = iter->second;
+
+  int result = net::OK;
+  bool user_modified_headers = false;
+  scoped_refptr<net::HttpResponseHeaders> override_headers(
+      new net::HttpResponseHeaders(""));
+  if (response->IsObject()) {
+    v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
+    gin::Dictionary dict(isolate, response.As<v8::Object>());
+
+    bool cancel = false;
+    dict.Get("cancel", &cancel);
+    if (cancel) {
+      result = net::ERR_BLOCKED_BY_CLIENT;
+    } else {
+      std::string status_line;
+      if (!dict.Get("statusLine", &status_line))
+        status_line = request.status_line;
+      v8::Local<v8::Value> value;
+      if (dict.Get("responseHeaders", &value) && value->IsObject()) {
+        user_modified_headers = true;
+        override_headers->ReplaceStatusLine(status_line);
+        gin::Converter<net::HttpResponseHeaders*>::FromV8(
+            isolate, value, override_headers.get());
+      }
+    }
+  }
+
+  if (user_modified_headers)
+    request.override_response_headers->swap(override_headers);
+
+  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
+      FROM_HERE, base::BindOnce(std::move(request.callback), result));
+  blocked_requests_.erase(iter);
 }
 
 void WebRequest::OnSendHeaders(extensions::WebRequestInfo* info,
@@ -371,7 +600,7 @@ void WebRequest::OnResponseStarted(extensions::WebRequestInfo* info,
 void WebRequest::OnErrorOccurred(extensions::WebRequestInfo* info,
                                  const network::ResourceRequest& request,
                                  int net_error) {
-  callbacks_.erase(info->id);
+  blocked_requests_.erase(info->id);
 
   HandleSimpleEvent(SimpleEvent::kOnErrorOccurred, info, request, net_error);
 }
@@ -379,13 +608,13 @@ void WebRequest::OnErrorOccurred(extensions::WebRequestInfo* info,
 void WebRequest::OnCompleted(extensions::WebRequestInfo* info,
                              const network::ResourceRequest& request,
                              int net_error) {
-  callbacks_.erase(info->id);
+  blocked_requests_.erase(info->id);
 
   HandleSimpleEvent(SimpleEvent::kOnCompleted, info, request, net_error);
 }
 
 void WebRequest::OnRequestWillBeDestroyed(extensions::WebRequestInfo* info) {
-  callbacks_.erase(info->id);
+  blocked_requests_.erase(info->id);
 }
 
 template <WebRequest::SimpleEvent event>
@@ -479,62 +708,6 @@ void WebRequest::HandleSimpleEvent(SimpleEvent event,
   info.listener.Run(gin::ConvertToV8(isolate, details));
 }
 
-template <typename Out, typename... Args>
-int WebRequest::HandleResponseEvent(ResponseEvent event,
-                                    extensions::WebRequestInfo* request_info,
-                                    net::CompletionOnceCallback callback,
-                                    Out out,
-                                    Args... args) {
-  const auto iter = response_listeners_.find(event);
-  if (iter == std::end(response_listeners_))
-    return net::OK;
-
-  const auto& info = iter->second;
-  if (!info.filter.MatchesRequest(request_info))
-    return net::OK;
-
-  callbacks_[request_info->id] = std::move(callback);
-
-  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
-  v8::HandleScope handle_scope(isolate);
-  gin_helper::Dictionary details(isolate, v8::Object::New(isolate));
-  FillDetails(&details, request_info, args...);
-
-  ResponseCallback response =
-      base::BindOnce(&WebRequest::OnListenerResult<Out>, base::Unretained(this),
-                     request_info->id, out);
-  info.listener.Run(gin::ConvertToV8(isolate, details), std::move(response));
-  return net::ERR_IO_PENDING;
-}
-
-template <typename T>
-void WebRequest::OnListenerResult(uint64_t id,
-                                  T out,
-                                  v8::Local<v8::Value> response) {
-  const auto iter = callbacks_.find(id);
-  if (iter == std::end(callbacks_))
-    return;
-
-  int result = net::OK;
-  if (response->IsObject()) {
-    v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
-    gin::Dictionary dict(isolate, response.As<v8::Object>());
-
-    bool cancel = false;
-    dict.Get("cancel", &cancel);
-    if (cancel)
-      result = net::ERR_BLOCKED_BY_CLIENT;
-    else
-      ReadFromResponse(isolate, &dict, out);
-  }
-
-  // The ProxyingURLLoaderFactory expects the callback to be executed
-  // asynchronously, because it used to work on IO thread before NetworkService.
-  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
-      FROM_HERE, base::BindOnce(std::move(callbacks_[id]), result));
-  callbacks_.erase(iter);
-}
-
 // static
 gin::Handle<WebRequest> WebRequest::FromOrCreate(
     v8::Isolate* isolate,

+ 29 - 9
shell/browser/api/electron_api_web_request.h

@@ -84,6 +84,10 @@ class WebRequest : public gin::Wrappable<WebRequest>, public WebRequestAPI {
   WebRequest(v8::Isolate* isolate, content::BrowserContext* browser_context);
   ~WebRequest() override;
 
+  // Contains info about requests that are blocked waiting for a response from
+  // the user.
+  struct BlockedRequest;
+
   enum class SimpleEvent {
     kOnSendHeaders,
     kOnBeforeRedirect,
@@ -91,6 +95,7 @@ class WebRequest : public gin::Wrappable<WebRequest>, public WebRequestAPI {
     kOnCompleted,
     kOnErrorOccurred,
   };
+
   enum class ResponseEvent {
     kOnBeforeRequest,
     kOnBeforeSendHeaders,
@@ -113,15 +118,30 @@ class WebRequest : public gin::Wrappable<WebRequest>, public WebRequestAPI {
   void HandleSimpleEvent(SimpleEvent event,
                          extensions::WebRequestInfo* info,
                          Args... args);
-  template <typename Out, typename... Args>
-  int HandleResponseEvent(ResponseEvent event,
-                          extensions::WebRequestInfo* info,
-                          net::CompletionOnceCallback callback,
-                          Out out,
-                          Args... args);
 
-  template <typename T>
-  void OnListenerResult(uint64_t id, T out, v8::Local<v8::Value> response);
+  int HandleOnBeforeRequestResponseEvent(
+      extensions::WebRequestInfo* info,
+      const network::ResourceRequest& request,
+      net::CompletionOnceCallback callback,
+      GURL* redirect_url);
+  int HandleOnBeforeSendHeadersResponseEvent(
+      extensions::WebRequestInfo* info,
+      const network::ResourceRequest& request,
+      BeforeSendHeadersCallback callback,
+      net::HttpRequestHeaders* headers);
+  int HandleOnHeadersReceivedResponseEvent(
+      extensions::WebRequestInfo* info,
+      const network::ResourceRequest& request,
+      net::CompletionOnceCallback callback,
+      const net::HttpResponseHeaders* original_response_headers,
+      scoped_refptr<net::HttpResponseHeaders>* override_response_headers);
+
+  void OnBeforeRequestListenerResult(uint64_t id,
+                                     v8::Local<v8::Value> response);
+  void OnBeforeSendHeadersListenerResult(uint64_t id,
+                                         v8::Local<v8::Value> response);
+  void OnHeadersReceivedListenerResult(uint64_t id,
+                                       v8::Local<v8::Value> response);
 
   class RequestFilter {
    public:
@@ -164,7 +184,7 @@ class WebRequest : public gin::Wrappable<WebRequest>, public WebRequestAPI {
 
   std::map<SimpleEvent, SimpleListenerInfo> simple_listeners_;
   std::map<ResponseEvent, ResponseListenerInfo> response_listeners_;
-  std::map<uint64_t, net::CompletionOnceCallback> callbacks_;
+  std::map<uint64_t, BlockedRequest> blocked_requests_;
 
   // Weak-ref, it manages us.
   raw_ptr<content::BrowserContext> browser_context_;

+ 2 - 2
spec/api-web-request-spec.ts

@@ -328,7 +328,7 @@ describe('webRequest module', () => {
         ses.webRequest.onBeforeSendHeaders((details, callback) => {
           const requestHeaders = details.requestHeaders;
           requestHeaders.Accept = '*/*;test/header';
-          callback({ requestHeaders: requestHeaders });
+          callback({ requestHeaders });
         });
         const { data } = await ajax('no-cors://fake-host/redirect');
         expect(data).to.equal('header-received');
@@ -341,7 +341,7 @@ describe('webRequest module', () => {
       ses.webRequest.onBeforeSendHeaders((details, callback) => {
         const requestHeaders = details.requestHeaders;
         requestHeaders.Origin = 'http://new-origin';
-        callback({ requestHeaders: requestHeaders });
+        callback({ requestHeaders });
       });
       const { data } = await ajax(defaultURL);
       expect(data).to.equal('/new/origin');