Browse Source

feat: add types to webRequest filter (#30914)

Milan Burda 2 years ago
parent
commit
ed7b5c44a2

+ 1 - 0
docs/api/structures/web-request-filter.md

@@ -1,3 +1,4 @@
 # WebRequestFilter Object
 
 * `urls` string[] - Array of [URL patterns](https://developer.mozilla.org/en-US/docs/Mozilla/Add-ons/WebExtensions/Match_patterns) that will be used to filter out the requests that do not match the URL patterns.
+* `types` String[] (optional) - Array of types that will be used to filter out the requests that do not match the types. When not specified, all types will be matched. Can be `mainFrame`, `subFrame`, `stylesheet`, `script`, `image`, `font`, `object`, `xhr`, `ping`, `cspReport`, `media` or `webSocket`.

+ 89 - 23
shell/browser/api/electron_api_web_request.cc

@@ -94,17 +94,34 @@ struct UserData : public base::SupportsUserData::Data {
   WebRequest* data;
 };
 
-// Test whether the URL of |request| matches |patterns|.
-bool MatchesFilterCondition(extensions::WebRequestInfo* info,
-                            const std::set<URLPattern>& patterns) {
-  if (patterns.empty())
-    return true;
-
-  for (const auto& pattern : patterns) {
-    if (pattern.MatchesURL(info->url))
-      return true;
+extensions::WebRequestResourceType ParseResourceType(const std::string& value) {
+  if (value == "mainFrame") {
+    return extensions::WebRequestResourceType::MAIN_FRAME;
+  } else if (value == "subFrame") {
+    return extensions::WebRequestResourceType::SUB_FRAME;
+  } else if (value == "stylesheet") {
+    return extensions::WebRequestResourceType::STYLESHEET;
+  } else if (value == "script") {
+    return extensions::WebRequestResourceType::SCRIPT;
+  } else if (value == "image") {
+    return extensions::WebRequestResourceType::IMAGE;
+  } else if (value == "font") {
+    return extensions::WebRequestResourceType::FONT;
+  } else if (value == "object") {
+    return extensions::WebRequestResourceType::OBJECT;
+  } else if (value == "xhr") {
+    return extensions::WebRequestResourceType::XHR;
+  } else if (value == "ping") {
+    return extensions::WebRequestResourceType::PING;
+  } else if (value == "cspReport") {
+    return extensions::WebRequestResourceType::CSP_REPORT;
+  } else if (value == "media") {
+    return extensions::WebRequestResourceType::MEDIA;
+  } else if (value == "webSocket") {
+    return extensions::WebRequestResourceType::WEB_SOCKET;
+  } else {
+    return extensions::WebRequestResourceType::OTHER;
   }
-  return false;
 }
 
 // Convert HttpResponseHeaders to V8.
@@ -247,17 +264,54 @@ void ReadFromResponse(v8::Isolate* isolate,
 
 gin::WrapperInfo WebRequest::kWrapperInfo = {gin::kEmbedderNativeGin};
 
-WebRequest::SimpleListenerInfo::SimpleListenerInfo(
-    std::set<URLPattern> patterns_,
-    SimpleListener listener_)
-    : url_patterns(std::move(patterns_)), listener(listener_) {}
+WebRequest::RequestFilter::RequestFilter(
+    std::set<URLPattern> url_patterns,
+    std::set<extensions::WebRequestResourceType> types)
+    : url_patterns_(std::move(url_patterns)), types_(std::move(types)) {}
+WebRequest::RequestFilter::RequestFilter(const RequestFilter&) = default;
+WebRequest::RequestFilter::RequestFilter() = default;
+WebRequest::RequestFilter::~RequestFilter() = default;
+
+void WebRequest::RequestFilter::AddUrlPattern(URLPattern pattern) {
+  url_patterns_.emplace(std::move(pattern));
+}
+
+void WebRequest::RequestFilter::AddType(
+    extensions::WebRequestResourceType type) {
+  types_.insert(type);
+}
+
+bool WebRequest::RequestFilter::MatchesURL(const GURL& url) const {
+  if (url_patterns_.empty())
+    return true;
+
+  for (const auto& pattern : url_patterns_) {
+    if (pattern.MatchesURL(url))
+      return true;
+  }
+  return false;
+}
+
+bool WebRequest::RequestFilter::MatchesType(
+    extensions::WebRequestResourceType type) const {
+  return types_.empty() || types_.find(type) != types_.end();
+}
+
+bool WebRequest::RequestFilter::MatchesRequest(
+    extensions::WebRequestInfo* info) const {
+  return MatchesURL(info->url) && MatchesType(info->web_request_type);
+}
+
+WebRequest::SimpleListenerInfo::SimpleListenerInfo(RequestFilter filter_,
+                                                   SimpleListener listener_)
+    : filter(std::move(filter_)), listener(listener_) {}
 WebRequest::SimpleListenerInfo::SimpleListenerInfo() = default;
 WebRequest::SimpleListenerInfo::~SimpleListenerInfo() = default;
 
 WebRequest::ResponseListenerInfo::ResponseListenerInfo(
-    std::set<URLPattern> patterns_,
+    RequestFilter filter_,
     ResponseListener listener_)
-    : url_patterns(std::move(patterns_)), listener(listener_) {}
+    : filter(std::move(filter_)), listener(listener_) {}
 WebRequest::ResponseListenerInfo::ResponseListenerInfo() = default;
 WebRequest::ResponseListenerInfo::~ResponseListenerInfo() = default;
 
@@ -392,8 +446,8 @@ void WebRequest::SetListener(Event event,
                              gin::Arguments* args) {
   v8::Local<v8::Value> arg;
 
-  // { urls }.
-  std::set<std::string> filter_patterns;
+  // { urls, types }.
+  std::set<std::string> filter_patterns, filter_types;
   gin::Dictionary dict(args->isolate());
   if (args->GetNext(&arg) && !arg->IsFunction()) {
     // Note that gin treats Function as Dictionary when doing conversions, so we
@@ -404,16 +458,18 @@ void WebRequest::SetListener(Event event,
         args->ThrowTypeError("Parameter 'filter' must have property 'urls'.");
         return;
       }
+      dict.Get("types", &filter_types);
       args->GetNext(&arg);
     }
   }
 
-  std::set<URLPattern> patterns;
+  RequestFilter filter;
+
   for (const std::string& filter_pattern : filter_patterns) {
     URLPattern pattern(URLPattern::SCHEME_ALL);
     const URLPattern::ParseResult result = pattern.Parse(filter_pattern);
     if (result == URLPattern::ParseResult::kSuccess) {
-      patterns.insert(pattern);
+      filter.AddUrlPattern(std::move(pattern));
     } else {
       const char* error_type = URLPattern::GetParseResultString(result);
       args->ThrowTypeError("Invalid url pattern " + filter_pattern + ": " +
@@ -422,6 +478,16 @@ void WebRequest::SetListener(Event event,
     }
   }
 
+  for (const std::string& filter_type : filter_types) {
+    auto type = ParseResourceType(filter_type);
+    if (type != extensions::WebRequestResourceType::OTHER) {
+      filter.AddType(type);
+    } else {
+      args->ThrowTypeError("Invalid type " + filter_type);
+      return;
+    }
+  }
+
   // Function or null.
   Listener listener;
   if (arg.IsEmpty() ||
@@ -433,7 +499,7 @@ void WebRequest::SetListener(Event event,
   if (listener.is_null())
     listeners->erase(event);
   else
-    (*listeners)[event] = {std::move(patterns), std::move(listener)};
+    (*listeners)[event] = {std::move(filter), std::move(listener)};
 }
 
 template <typename... Args>
@@ -445,7 +511,7 @@ void WebRequest::HandleSimpleEvent(SimpleEvent event,
     return;
 
   const auto& info = iter->second;
-  if (!MatchesFilterCondition(request_info, info.url_patterns))
+  if (!info.filter.MatchesRequest(request_info))
     return;
 
   v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
@@ -466,7 +532,7 @@ int WebRequest::HandleResponseEvent(ResponseEvent event,
     return net::OK;
 
   const auto& info = iter->second;
-  if (!MatchesFilterCondition(request_info, info.url_patterns))
+  if (!info.filter.MatchesRequest(request_info))
     return net::OK;
 
   callbacks_[request_info->id] = std::move(callback);

+ 25 - 4
shell/browser/api/electron_api_web_request.h

@@ -123,20 +123,41 @@ class WebRequest : public gin::Wrappable<WebRequest>, public WebRequestAPI {
   template <typename T>
   void OnListenerResult(uint64_t id, T out, v8::Local<v8::Value> response);
 
+  class RequestFilter {
+   public:
+    RequestFilter(std::set<URLPattern>,
+                  std::set<extensions::WebRequestResourceType>);
+    RequestFilter(const RequestFilter&);
+    RequestFilter();
+    ~RequestFilter();
+
+    void AddUrlPattern(URLPattern pattern);
+    void AddType(extensions::WebRequestResourceType type);
+
+    bool MatchesRequest(extensions::WebRequestInfo* info) const;
+
+   private:
+    bool MatchesURL(const GURL& url) const;
+    bool MatchesType(extensions::WebRequestResourceType type) const;
+
+    std::set<URLPattern> url_patterns_;
+    std::set<extensions::WebRequestResourceType> types_;
+  };
+
   struct SimpleListenerInfo {
-    std::set<URLPattern> url_patterns;
+    RequestFilter filter;
     SimpleListener listener;
 
-    SimpleListenerInfo(std::set<URLPattern>, SimpleListener);
+    SimpleListenerInfo(RequestFilter, SimpleListener);
     SimpleListenerInfo();
     ~SimpleListenerInfo();
   };
 
   struct ResponseListenerInfo {
-    std::set<URLPattern> url_patterns;
+    RequestFilter filter;
     ResponseListener listener;
 
-    ResponseListenerInfo(std::set<URLPattern>, ResponseListener);
+    ResponseListenerInfo(RequestFilter, ResponseListener);
     ResponseListenerInfo();
     ~ResponseListenerInfo();
   };

+ 19 - 8
spec/api-web-request-spec.ts

@@ -63,25 +63,36 @@ describe('webRequest module', () => {
       ses.webRequest.onBeforeRequest(null);
     });
 
+    const cancel = (details: Electron.OnBeforeRequestListenerDetails, callback: (response: Electron.CallbackResponse) => void) => {
+      callback({ cancel: true });
+    };
+
     it('can cancel the request', async () => {
-      ses.webRequest.onBeforeRequest((details, callback) => {
-        callback({
-          cancel: true
-        });
-      });
+      ses.webRequest.onBeforeRequest(cancel);
       await expect(ajax(defaultURL)).to.eventually.be.rejected();
     });
 
     it('can filter URLs', async () => {
       const filter = { urls: [defaultURL + 'filter/*'] };
-      ses.webRequest.onBeforeRequest(filter, (details, callback) => {
-        callback({ cancel: true });
-      });
+      ses.webRequest.onBeforeRequest(filter, cancel);
       const { data } = await ajax(`${defaultURL}nofilter/test`);
       expect(data).to.equal('/nofilter/test');
       await expect(ajax(`${defaultURL}filter/test`)).to.eventually.be.rejected();
     });
 
+    it('can filter URLs and types', async () => {
+      const filter1: Electron.WebRequestFilter = { urls: [defaultURL + 'filter/*'], types: ['xhr'] };
+      ses.webRequest.onBeforeRequest(filter1, cancel);
+      const { data } = await ajax(`${defaultURL}nofilter/test`);
+      expect(data).to.equal('/nofilter/test');
+      await expect(ajax(`${defaultURL}filter/test`)).to.eventually.be.rejected();
+
+      const filter2: Electron.WebRequestFilter = { urls: [defaultURL + 'filter/*'], types: ['stylesheet'] };
+      ses.webRequest.onBeforeRequest(filter2, cancel);
+      expect((await ajax(`${defaultURL}nofilter/test`)).data).to.equal('/nofilter/test');
+      expect((await ajax(`${defaultURL}filter/test`)).data).to.equal('/filter/test');
+    });
+
     it('receives details object', async () => {
       ses.webRequest.onBeforeRequest((details, callback) => {
         expect(details.id).to.be.a('number');