1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package info.magnolia.cms.security;
35
36 import info.magnolia.cms.filters.OncePerRequestAbstractMgnlFilter;
37 import info.magnolia.cms.security.auth.callback.HttpClientCallback;
38
39 import java.io.IOException;
40 import java.util.ArrayList;
41 import java.util.List;
42
43 import javax.servlet.FilterChain;
44 import javax.servlet.ServletException;
45 import javax.servlet.http.HttpServletRequest;
46 import javax.servlet.http.HttpServletResponse;
47 import javax.servlet.http.HttpServletResponseWrapper;
48
49
50 import static info.magnolia.cms.util.ExceptionUtil.rethrow;
51 import static info.magnolia.cms.util.ExceptionUtil.wasCausedBy;
52 import static javax.servlet.http.HttpServletResponse.SC_FORBIDDEN;
53 import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED;
54
55
56
57
58
59
60
61
62
63
64
65
66
67 public class SecurityCallbackFilter extends OncePerRequestAbstractMgnlFilter {
68 private static final org.slf4j.Logger log = org.slf4j.LoggerFactory.getLogger(SecurityCallbackFilter.class);
69
70
71
72
73 private final List<HttpClientCallback> clientCallbacks;
74
75 public SecurityCallbackFilter() {
76 this.clientCallbacks = new ArrayList<HttpClientCallback>();
77 }
78
79 @Override
80 public void doFilter(HttpServletRequest request, HttpServletResponse originalResponse, FilterChain chain) throws IOException, ServletException {
81 final StatusSniffingResponseWrapper response = new StatusSniffingResponseWrapper(originalResponse);
82 try {
83 chain.doFilter(request, response);
84 if (needsCallback(response)) {
85 selectAndHandleCallback(request, response);
86 }
87 } catch (Throwable e) {
88
89 if (wasCausedBy(e, javax.jcr.AccessDeniedException.class)) {
90 response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
91 selectAndHandleCallback(request, response);
92 } else {
93 rethrow(e, IOException.class, ServletException.class);
94 }
95 }
96 }
97
98 protected boolean needsCallback(StatusSniffingResponseWrapper response) {
99 final int status = response.getStatus();
100 return status == SC_FORBIDDEN || status == SC_UNAUTHORIZED;
101 }
102
103 protected void selectAndHandleCallback(HttpServletRequest request, StatusSniffingResponseWrapper response) {
104 selectClientCallback(request).handle(request, response);
105 }
106
107 protected HttpClientCallback selectClientCallback(HttpServletRequest request) {
108 for (HttpClientCallback clientCallback : clientCallbacks) {
109 if (clientCallback.accepts(request)) {
110 return clientCallback;
111 }
112 }
113 throw new IllegalStateException("No configured callback accepted this request " + request.toString());
114 }
115
116
117 public void addClientCallback(HttpClientCallback clientCallback) {
118 this.clientCallbacks.add(clientCallback);
119 }
120
121 public void setClientCallbacks(List<HttpClientCallback> clientCallbacks) {
122 this.clientCallbacks.addAll(clientCallbacks);
123 }
124
125
126 public List<HttpClientCallback> getClientCallbacks() {
127 return clientCallbacks;
128 }
129
130
131
132
133
134
135
136
137 public static class StatusSniffingResponseWrapper extends HttpServletResponseWrapper {
138 private int status = SC_OK;
139
140 public StatusSniffingResponseWrapper(HttpServletResponse response) {
141 super(response);
142 }
143
144 public int getStatus() {
145 return status;
146 }
147
148 @Override
149 public void reset() {
150 super.reset();
151 status = SC_OK;
152 }
153
154 @Override
155 public void setStatus(int sc) {
156 super.setStatus(sc);
157 this.status = sc;
158 }
159
160 @Override
161 public void setStatus(int sc, String sm) {
162 super.setStatus(sc, sm);
163 this.status = sc;
164 }
165
166 @Override
167 public void sendRedirect(String location) throws IOException {
168 super.sendRedirect(location);
169 this.status = SC_MOVED_TEMPORARILY;
170 }
171
172 @Override
173 public void sendError(int sc) throws IOException {
174 super.sendError(sc);
175 this.status = sc;
176 }
177
178 @Override
179 public void sendError(int sc, String msg) throws IOException {
180 super.sendError(sc, msg);
181 this.status = sc;
182 }
183 }
184 }