View Javadoc
1   /**
2    * This file Copyright (c) 2020 Magnolia International
3    * Ltd.  (http://www.magnolia-cms.com). All rights reserved.
4    *
5    *
6    * This file is dual-licensed under both the Magnolia
7    * Network Agreement and the GNU General Public License.
8    * You may elect to use one or the other of these licenses.
9    *
10   * This file is distributed in the hope that it will be
11   * useful, but AS-IS and WITHOUT ANY WARRANTY; without even the
12   * implied warranty of MERCHANTABILITY or FITNESS FOR A
13   * PARTICULAR PURPOSE, TITLE, or NONINFRINGEMENT.
14   * Redistribution, except as permitted by whichever of the GPL
15   * or MNA you select, is prohibited.
16   *
17   * 1. For the GPL license (GPL), you can redistribute and/or
18   * modify this file under the terms of the GNU General
19   * Public License, Version 3, as published by the Free Software
20   * Foundation.  You should have received a copy of the GNU
21   * General Public License, Version 3 along with this program;
22   * if not, write to the Free Software Foundation, Inc., 51
23   * Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
24   *
25   * 2. For the Magnolia Network Agreement (MNA), this file
26   * and the accompanying materials are made available under the
27   * terms of the MNA which accompanies this distribution, and
28   * is available at http://www.magnolia-cms.com/mna.html
29   *
30   * Any modifications to this file must keep this entire header
31   * intact.
32   *
33   */
34  package info.magnolia.cors;
35  
36  import static info.magnolia.cors.AbstractCorsFilter.Headers.*;
37  
38  import info.magnolia.cms.filters.AbstractMgnlFilter;
39  
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
46  
47  import javax.inject.Inject;
48  import javax.servlet.FilterChain;
49  import javax.servlet.ServletException;
50  import javax.servlet.http.HttpServletRequest;
51  import javax.servlet.http.HttpServletResponse;
52  
53  import org.apache.commons.lang3.StringUtils;
54  
55  import com.google.common.base.Joiner;
56  
57  /**
58   * Filter that handles CORS requests.
59   */
60  public class CorsResponseFilter extends AbstractMgnlFilter {
61  
62      private static final String WILDCARD = "*";
63  
64      private final CorsConfiguration configuration;
65  
66      @Inject
67      public CorsResponseFilter(final CorsConfiguration configuration) {
68          this.configuration = configuration;
69      }
70  
71      @Override
72      public void doFilter(final HttpServletRequest request, final HttpServletResponse response, final FilterChain chain) throws IOException, ServletException {
73          final String origin = request.getHeader(ORIGIN.getName());
74          if (!isOriginAllowed(origin)) {
75              throw new CorsException(String.format("Origin [%s] not allowed", origin));
76          }
77          if (isPreflightRequest(request)) {
78              final String requestMethod = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD.getName());
79              if (StringUtils.isBlank(requestMethod)) {
80                  throw new CorsException(String.format("Header [%s] value must not be null or empty", ACCESS_CONTROL_REQUEST_METHOD.getName()));
81              }
82              if (!isMethodAllowed(requestMethod)) {
83                  throw new CorsException(String.format("Method [%s] is not allowed", requestMethod.toUpperCase()));
84              }
85              final Set<String> requestHeaders = accessControlRequestHeaders(request);
86              if (!areHeadersAllowed(requestHeaders)) {
87                  requestHeaders.removeAll(configuration.getAllowedHeaders());
88                  throw new CorsException(String.format("Some of the request headers %s are not allowed", requestHeaders));
89              }
90  
91              setVaryHeader(response, ACCESS_CONTROL_REQUEST_METHOD.getName());
92              setVaryHeader(response, ACCESS_CONTROL_REQUEST_HEADERS.getName());
93  
94              if (configuration.getMaxAge() > 0) {
95                  response.addHeader(ACCESS_CONTROL_MAX_AGE.getName(), String.valueOf(configuration.getMaxAge()));
96              }
97              response.addHeader(ACCESS_CONTROL_ALLOW_METHODS.getName(), Joiner.on(',').join(configuration.getAllowedMethods()));
98              response.addHeader(ACCESS_CONTROL_ALLOW_HEADERS.getName(), Joiner.on(',').join(configuration.getAllowedHeaders()));
99              response.setStatus(204);
100         } else {
101             final String method = request.getMethod();
102             if (!isMethodAllowed(method)) {
103                 throw new CorsException(String.format("Method [%s] is not allowed", method.toUpperCase()));
104             }
105         }
106         addStandardHeaders(request, response);
107     }
108 
109     private boolean isPreflightRequest(final HttpServletRequest request) {
110         final String method = request.getMethod();
111         final String requestMethodHeader = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD.getName());
112 
113         return AbstractCorsFilter.OPTIONS_METHOD.equals(method)
114                 && requestMethodHeader != null;
115     }
116 
117     private boolean isMethodAllowed(final String requestMethod) {
118         if (configuration.getAllowedMethods().contains(WILDCARD)) {
119             return true;
120         }
121         return configuration.getAllowedMethods().contains(requestMethod);
122     }
123 
124     private Set<String> accessControlRequestHeaders(final HttpServletRequest request) {
125         final Set<String> result = new HashSet<>();
126         final Enumeration<String> headers = request.getHeaders(ACCESS_CONTROL_REQUEST_HEADERS.getName());
127         while (headers.hasMoreElements()) {
128             result.add(headers.nextElement().toLowerCase());
129         }
130         return result;
131     }
132 
133     private boolean areHeadersAllowed(final Set<String> requestHeaders) {
134         if (configuration.getAllowedHeaders().contains(WILDCARD)) {
135             return true;
136         }
137         return configuration.getAllowedHeaders().containsAll(requestHeaders);
138     }
139 
140     private void addStandardHeaders(final HttpServletRequest request, final HttpServletResponse response) {
141         final String origin = request.getHeader(ORIGIN.getName());
142         final boolean anyOriginAllowed = configuration.getAllowedOrigins().contains(WILDCARD);
143 
144         if (!anyOriginAllowed) {
145             setVaryHeader(response, ORIGIN.getName());
146         }
147 
148         if (anyOriginAllowed) {
149             response.addHeader(ACCESS_CONTROL_ALLOW_ORIGIN.getName(), WILDCARD);
150         } else {
151             response.addHeader(ACCESS_CONTROL_ALLOW_ORIGIN.getName(), origin);
152         }
153 
154         if (configuration.isSupportsCredentials()) {
155             response.addHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS.getName(), "true");
156         }
157     }
158 
159     private boolean isOriginAllowed(final String origin) {
160         if (origin == null || origin.isEmpty()) {
161             return false;
162         }
163         if (configuration.getAllowedOrigins().contains(WILDCARD)) {
164             return true;
165         }
166         return configuration.getAllowedOrigins().contains(origin);
167     }
168 
169     private void setVaryHeader(final HttpServletResponse response, final String name) {
170         final Collection<String> varyHeaders = response.getHeaders(VARY.getName());
171         final String headerName = name.trim();
172 
173         if (varyHeaders.size() == 1 && varyHeaders.stream().anyMatch(WILDCARD::equals)) {
174             return;
175         }
176 
177         if (varyHeaders.size() == 0) {
178             response.addHeader(VARY.getName(), headerName);
179             return;
180         }
181 
182         if (WILDCARD.equals(headerName)) {
183             response.setHeader(VARY.getName(), WILDCARD);
184             return;
185         }
186 
187         if (varyHeaders.stream().map(String::trim).anyMatch(headerName::equals)) {
188             return;
189         }
190 
191         final ArrayList<String> headerValues = new ArrayList<>(varyHeaders);
192         headerValues.add(headerName);
193         response.setHeader(VARY.getName(), Joiner.on(',').join(headerValues));
194     }
195 }