1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package info.magnolia.cms.security;
35
36 import static info.magnolia.cms.util.ExceptionUtil.*;
37 import static javax.servlet.http.HttpServletResponse.*;
38
39 import info.magnolia.cms.filters.OncePerRequestAbstractMgnlFilter;
40 import info.magnolia.cms.security.auth.callback.HttpClientCallback;
41
42 import java.io.IOException;
43 import java.util.ArrayList;
44 import java.util.List;
45
46 import javax.servlet.FilterChain;
47 import javax.servlet.ServletException;
48 import javax.servlet.http.HttpServletRequest;
49 import javax.servlet.http.HttpServletResponse;
50 import javax.servlet.http.HttpServletResponseWrapper;
51
52
53
54
55
56
57
58
59
60
61
62
63
64 public class SecurityCallbackFilter extends OncePerRequestAbstractMgnlFilter {
65
66
67
68
69 private final List<HttpClientCallback> clientCallbacks;
70
71 public SecurityCallbackFilter() {
72 this.clientCallbacks = new ArrayList<>();
73 }
74
75 @Override
76 public void doFilter(HttpServletRequest request, HttpServletResponse originalResponse, FilterChain chain) throws IOException, ServletException {
77 final SecurityCallbackFilter.StatusSniffingResponseWrapper response;
78 if (originalResponse instanceof SecurityCallbackFilter.StatusSniffingResponseWrapper) {
79 response = (SecurityCallbackFilter.StatusSniffingResponseWrapper) originalResponse;
80 } else {
81 response = new SecurityCallbackFilter.StatusSniffingResponseWrapper(originalResponse);
82 }
83 try {
84 chain.doFilter(request, response);
85 if (needsCallback(response)) {
86 selectAndHandleCallback(request, response);
87 }
88 } catch (Throwable e) {
89
90 if (wasCausedBy(e, javax.jcr.AccessDeniedException.class)) {
91 response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
92 selectAndHandleCallback(request, response);
93 } else {
94 rethrow(e, IOException.class, ServletException.class);
95 }
96 }
97 }
98
99 protected boolean needsCallback(StatusSniffingResponseWrapper response) {
100 final int status = response.getStatus();
101 return status == SC_FORBIDDEN || status == SC_UNAUTHORIZED;
102 }
103
104 protected void selectAndHandleCallback(HttpServletRequest request, StatusSniffingResponseWrapper response) {
105 selectClientCallback(request).handle(request, response);
106 }
107
108 protected HttpClientCallback selectClientCallback(HttpServletRequest request) {
109 for (HttpClientCallback clientCallback : clientCallbacks) {
110 if (clientCallback.accepts(request)) {
111 return clientCallback;
112 }
113 }
114 throw new IllegalStateException("No configured callback accepted this request " + request.toString());
115 }
116
117
118 public void addClientCallback(HttpClientCallback clientCallback) {
119 this.clientCallbacks.add(clientCallback);
120 }
121
122 public void setClientCallbacks(List<HttpClientCallback> clientCallbacks) {
123 this.clientCallbacks.addAll(clientCallbacks);
124 }
125
126
127 public List<HttpClientCallback> getClientCallbacks() {
128 return clientCallbacks;
129 }
130
131
132
133
134
135
136
137
138 public static class StatusSniffingResponseWrapper extends HttpServletResponseWrapper {
139 private int status = SC_OK;
140
141 public StatusSniffingResponseWrapper(HttpServletResponse response) {
142 super(response);
143 }
144
145 public int getStatus() {
146 return status;
147 }
148
149 @Override
150 public void reset() {
151 super.reset();
152 status = SC_OK;
153 }
154
155 @Override
156 public void setStatus(int sc) {
157 super.setStatus(sc);
158 this.status = sc;
159 }
160
161 @Override
162 public void setStatus(int sc, String sm) {
163 super.setStatus(sc, sm);
164 this.status = sc;
165 }
166
167 @Override
168 public void sendRedirect(String location) throws IOException {
169 super.sendRedirect(location);
170 this.status = SC_MOVED_TEMPORARILY;
171 }
172
173 @Override
174 public void sendError(int sc) throws IOException {
175 super.sendError(sc);
176 this.status = sc;
177 }
178
179 @Override
180 public void sendError(int sc, String msg) throws IOException {
181 super.sendError(sc, msg);
182 this.status = sc;
183 }
184 }
185 }