1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package info.magnolia.cors;
35
import static info.magnolia.cors.AbstractCorsFilter.Headers.*;

import info.magnolia.cms.filters.AbstractMgnlFilter;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;

import javax.inject.Inject;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;

import com.google.common.base.Joiner;
56
57
58
59
60 public class CorsResponseFilter extends AbstractMgnlFilter {
61
62 private static final String WILDCARD = "*";
63
64 private final CorsConfiguration configuration;
65
66 @Inject
67 public CorsResponseFilter(final CorsConfiguration configuration) {
68 this.configuration = configuration;
69 }
70
71 @Override
72 public void doFilter(final HttpServletRequest request, final HttpServletResponse response, final FilterChain chain) throws IOException, ServletException {
73 final String origin = request.getHeader(ORIGIN.getName());
74 if (!isOriginAllowed(origin)) {
75 throw new CorsException(String.format("Origin [%s] not allowed", origin));
76 }
77 if (isPreflightRequest(request)) {
78 final String requestMethod = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD.getName());
79 if (StringUtils.isBlank(requestMethod)) {
80 throw new CorsException(String.format("Header [%s] value must not be null or empty", ACCESS_CONTROL_REQUEST_METHOD.getName()));
81 }
82 if (!isMethodAllowed(requestMethod)) {
83 throw new CorsException(String.format("Method [%s] is not allowed", requestMethod.toUpperCase()));
84 }
85 final Set<String> requestHeaders = accessControlRequestHeaders(request);
86 if (!areHeadersAllowed(requestHeaders)) {
87 requestHeaders.removeAll(configuration.getAllowedHeaders());
88 throw new CorsException(String.format("Some of the request headers %s are not allowed", requestHeaders));
89 }
90
91 setVaryHeader(response, ACCESS_CONTROL_REQUEST_METHOD.getName());
92 setVaryHeader(response, ACCESS_CONTROL_REQUEST_HEADERS.getName());
93
94 if (configuration.getMaxAge() > 0) {
95 response.addHeader(ACCESS_CONTROL_MAX_AGE.getName(), String.valueOf(configuration.getMaxAge()));
96 }
97 response.addHeader(ACCESS_CONTROL_ALLOW_METHODS.getName(), Joiner.on(',').join(configuration.getAllowedMethods()));
98 response.addHeader(ACCESS_CONTROL_ALLOW_HEADERS.getName(), Joiner.on(',').join(configuration.getAllowedHeaders()));
99 response.setStatus(204);
100 } else {
101 final String method = request.getMethod();
102 if (!isMethodAllowed(method)) {
103 throw new CorsException(String.format("Method [%s] is not allowed", method.toUpperCase()));
104 }
105 }
106 addStandardHeaders(request, response);
107 }
108
109 private boolean isPreflightRequest(final HttpServletRequest request) {
110 final String method = request.getMethod();
111 final String requestMethodHeader = request.getHeader(ACCESS_CONTROL_REQUEST_METHOD.getName());
112
113 return AbstractCorsFilter.OPTIONS_METHOD.equals(method)
114 && requestMethodHeader != null;
115 }
116
117 private boolean isMethodAllowed(final String requestMethod) {
118 if (configuration.getAllowedMethods().contains(WILDCARD)) {
119 return true;
120 }
121 return configuration.getAllowedMethods().contains(requestMethod);
122 }
123
124 private Set<String> accessControlRequestHeaders(final HttpServletRequest request) {
125 final Set<String> result = new HashSet<>();
126 final Enumeration<String> headers = request.getHeaders(ACCESS_CONTROL_REQUEST_HEADERS.getName());
127 while (headers.hasMoreElements()) {
128 result.add(headers.nextElement().toLowerCase());
129 }
130 return result;
131 }
132
133 private boolean areHeadersAllowed(final Set<String> requestHeaders) {
134 if (configuration.getAllowedHeaders().contains(WILDCARD)) {
135 return true;
136 }
137 return configuration.getAllowedHeaders().containsAll(requestHeaders);
138 }
139
140 private void addStandardHeaders(final HttpServletRequest request, final HttpServletResponse response) {
141 final String origin = request.getHeader(ORIGIN.getName());
142 final boolean anyOriginAllowed = configuration.getAllowedOrigins().contains(WILDCARD);
143
144 if (!anyOriginAllowed) {
145 setVaryHeader(response, ORIGIN.getName());
146 }
147
148 if (anyOriginAllowed) {
149 response.addHeader(ACCESS_CONTROL_ALLOW_ORIGIN.getName(), WILDCARD);
150 } else {
151 response.addHeader(ACCESS_CONTROL_ALLOW_ORIGIN.getName(), origin);
152 }
153
154 if (configuration.isSupportsCredentials()) {
155 response.addHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS.getName(), "true");
156 }
157 }
158
159 private boolean isOriginAllowed(final String origin) {
160 if (origin == null || origin.isEmpty()) {
161 return false;
162 }
163 if (configuration.getAllowedOrigins().contains(WILDCARD)) {
164 return true;
165 }
166 return configuration.getAllowedOrigins().contains(origin);
167 }
168
169 private void setVaryHeader(final HttpServletResponse response, final String name) {
170 final Collection<String> varyHeaders = response.getHeaders(VARY.getName());
171 final String headerName = name.trim();
172
173 if (varyHeaders.size() == 1 && varyHeaders.stream().anyMatch(WILDCARD::equals)) {
174 return;
175 }
176
177 if (varyHeaders.size() == 0) {
178 response.addHeader(VARY.getName(), headerName);
179 return;
180 }
181
182 if (WILDCARD.equals(headerName)) {
183 response.setHeader(VARY.getName(), WILDCARD);
184 return;
185 }
186
187 if (varyHeaders.stream().map(String::trim).anyMatch(headerName::equals)) {
188 return;
189 }
190
191 final ArrayList<String> headerValues = new ArrayList<>(varyHeaders);
192 headerValues.add(headerName);
193 response.setHeader(VARY.getName(), Joiner.on(',').join(headerValues));
194 }
195 }