1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package info.magnolia.cors;
35
36 import info.magnolia.cms.filters.AbstractMgnlFilter;
37 import info.magnolia.cms.filters.MgnlFilter;
38
39 import java.io.IOException;
40 import java.net.MalformedURLException;
41 import java.net.URI;
42 import java.net.URISyntaxException;
43 import java.net.URL;
44 import java.util.Optional;
45
46 import javax.servlet.FilterChain;
47 import javax.servlet.ServletException;
48 import javax.servlet.http.HttpServletRequest;
49 import javax.servlet.http.HttpServletResponse;
50
51 import org.apache.commons.lang3.StringUtils;
52 import org.slf4j.Logger;
53 import org.slf4j.LoggerFactory;
54
55 import com.machinezoo.noexception.Exceptions;
56
57
58
59
60
61
62
63
64 public abstract class AbstractCorsFilter extends AbstractMgnlFilter {
65
66 public static final String OPTIONS_METHOD = "OPTIONS";
67
68 private static final Logger log = LoggerFactory.getLogger(AbstractCorsFilter.class);
69
70 @Override
71 public void doFilter(final HttpServletRequest request, final HttpServletResponse response, final FilterChain filterChain) throws IOException, ServletException {
72 final String current = currentHost(request);
73 final RequestType requestType = RequestType.from(current, request);
74 try {
75 switch (requestType) {
76 case NOT_CORS:
77 filterChain.doFilter(request, response);
78 break;
79 case PRE_FLIGHT:
80 final Optional<MgnlFilter> corsResponseFilter = getCorsResponseFilter();
81 if (corsResponseFilter.isPresent()) {
82 corsResponseFilter.get().doFilter(request, response, filterChain);
83 } else {
84 filterChain.doFilter(request, response);
85 }
86 break;
87 case CORS:
88 getCorsResponseFilter().ifPresent(Exceptions.wrap().consumer(filter -> filter.doFilter(request, response, filterChain)));
89 filterChain.doFilter(request, response);
90 break;
91 default:
92 handleInvalid(response);
93 break;
94 }
95 } catch (CorsException e) {
96 log.warn("CORS failed due to: {}", e.getMessage());
97 handleInvalid(response);
98 }
99 }
100
101 protected abstract Optional<MgnlFilter> getCorsResponseFilter();
102
103 private String currentHost(final HttpServletRequest request) throws MalformedURLException {
104 final URL url = new URL(request.getRequestURL().toString());
105 final String protocol = url.getProtocol();
106 final String authority = url.getAuthority();
107 return String.format("%s://%s", protocol, authority);
108 }
109
110 private void handleInvalid(final HttpServletResponse response) {
111 response.setContentType("text/plain");
112 response.setStatus(HttpServletResponse.SC_FORBIDDEN);
113 response.resetBuffer();
114 }
115
116
117
118
119 public enum RequestType {
120 CORS, PRE_FLIGHT, NOT_CORS, INVALID_CORS;
121
122 public static RequestType from(final String currentHost, final HttpServletRequest request) {
123 final String origin = request.getHeader(Headers.ORIGIN.getName());
124 if (origin == null) {
125 return RequestType.NOT_CORS;
126 } else if (origin.isEmpty() || !isValidOrigin(origin)) {
127 return RequestType.INVALID_CORS;
128 } else if (isSameOrigin(currentHost, origin)) {
129 return RequestType.NOT_CORS;
130 } else {
131 final String method = request.getMethod();
132 if (method != null) {
133 if (OPTIONS_METHOD.equals(method)) {
134 final String accessControlRequestMethodHeader = request.getHeader(Headers.ACCESS_CONTROL_REQUEST_METHOD.getName());
135 if (StringUtils.isNotBlank(accessControlRequestMethodHeader)) {
136 return RequestType.PRE_FLIGHT;
137 } else if (accessControlRequestMethodHeader != null && accessControlRequestMethodHeader.isEmpty()) {
138 return RequestType.INVALID_CORS;
139 } else {
140 return RequestType.CORS;
141 }
142 } else {
143 return RequestType.CORS;
144 }
145 }
146 }
147 return RequestType.INVALID_CORS;
148 }
149
150 private static boolean isValidOrigin(final String origin) {
151 if (StringUtils.contains(origin, '%')) {
152 return false;
153 }
154 final URI uri;
155 try {
156 uri = new URI(origin);
157 } catch (URISyntaxException e) {
158 return false;
159 }
160 return uri.getScheme() != null;
161 }
162
163 private static boolean isSameOrigin(final String currentUri, final String origin) {
164 return currentUri.equals(origin);
165 }
166 }
167
168
169
170
171 public enum Headers {
172 ACCESS_CONTROL_REQUEST_METHOD("Access-Control-Request-Method"),
173 ACCESS_CONTROL_REQUEST_HEADERS("Access-Control-Request-Headers"),
174 ACCESS_CONTROL_ALLOW_ORIGIN("Access-Control-Allow-Origin"),
175 ACCESS_CONTROL_ALLOW_CREDENTIALS("Access-Control-Allow-Credentials"),
176 ACCESS_CONTROL_MAX_AGE("Access-Control-Max-Age"),
177 ACCESS_CONTROL_ALLOW_METHODS("Access-Control-Allow-Methods"),
178 ACCESS_CONTROL_ALLOW_HEADERS("Access-Control-Allow-Headers"),
179 ORIGIN("Origin"),
180 VARY("Vary");
181
182 private final String headerName;
183
184 Headers(final String headerName) {
185 this.headerName = headerName;
186 }
187
188 public String getName() {
189 return headerName;
190 }
191 }
192 }