1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 package info.magnolia.module.cache.filter;
35
36 import info.magnolia.cms.core.Path;
37 import info.magnolia.cms.util.RequestHeaderUtil;
38
39 import java.io.ByteArrayInputStream;
40 import java.io.ByteArrayOutputStream;
41 import java.io.File;
42 import java.io.FileInputStream;
43 import java.io.FileOutputStream;
44 import java.io.IOException;
45 import java.io.OutputStream;
46 import java.io.OutputStreamWriter;
47 import java.io.PrintWriter;
48 import java.util.Collection;
49 import java.util.Iterator;
50
51 import javax.servlet.ServletOutputStream;
52 import javax.servlet.http.HttpServletResponse;
53 import javax.servlet.http.HttpServletResponseWrapper;
54
55 import org.apache.commons.collections.MultiMap;
56 import org.apache.commons.collections.map.MultiValueMap;
57 import org.apache.commons.httpclient.util.DateParseException;
58 import org.apache.commons.httpclient.util.DateUtil;
59 import org.apache.commons.io.FileUtils;
60 import org.apache.commons.io.IOUtils;
61 import org.apache.commons.io.output.ThresholdingOutputStream;
62 import org.slf4j.Logger;
63 import org.slf4j.LoggerFactory;
64
65
66
67
68
69
70
71
72
73 public class CacheResponseWrapper extends HttpServletResponseWrapper {
74
75 public static final int DEFAULT_THRESHOLD = 500 * 1024;
76 private static final Logger log = LoggerFactory.getLogger(CacheResponseWrapper.class);
77
78 private ServletOutputStream wrappedStream;
79 private PrintWriter wrappedWriter = null;
80 private final MultiMap headers = new MultiValueMap();
81 private int status = SC_OK;
82 private boolean isError;
83 private String redirectionLocation;
84 private final HttpServletResponse originalResponse;
85 private File contentFile;
86 private long contentLength = -1;
87 private ResponseExpirationCalculator responseExpirationCalculator;
88
89 private final AbstractThresholdingCacheOutputStream thresholdingOutputStream;
90 private final boolean serveIfThresholdReached;
91
92 private String errorMsg;
93
94 public CacheResponseWrapper(final HttpServletResponse response, int threshold, boolean serveIfThresholdReached) {
95 this(response, threshold, serveIfThresholdReached, null);
96
97 }
98
99 public CacheResponseWrapper(final HttpServletResponse response, int threshold, boolean serveIfThresholdReached, AbstractThresholdingCacheOutputStream stream) {
100 super(response);
101 this.serveIfThresholdReached = serveIfThresholdReached;
102 this.originalResponse = response;
103 if (stream == null) {
104 this.thresholdingOutputStream = new ThresholdingCacheOutputStream(threshold);
105 } else {
106 this.thresholdingOutputStream = stream;
107 }
108 this.wrappedStream = new SimpleServletOutputStream(thresholdingOutputStream);
109 }
110
111 public boolean isThresholdExceeded() {
112 return thresholdingOutputStream.isThresholdExceeded();
113 }
114
115 public byte[] getBufferedContent() {
116 if (this.thresholdingOutputStream.getInMemoryBuffer() instanceof ByteArrayOutputStream) {
117 return ((ByteArrayOutputStream) this.thresholdingOutputStream.getInMemoryBuffer()).toByteArray();
118 }
119 return new byte[] {};
120 }
121
122 public File getContentFile() {
123 return contentFile;
124 }
125
126
127 @Override
128 public ServletOutputStream getOutputStream() throws IOException {
129 return wrappedStream;
130 }
131
132 public ThresholdingOutputStream getThresholdingOutputStream() throws IOException {
133 return thresholdingOutputStream;
134 }
135
136 @Override
137 public PrintWriter getWriter() throws IOException {
138 if (wrappedWriter == null) {
139 String encoding = getCharacterEncoding();
140 wrappedWriter = encoding != null
141 ? new PrintWriter(new OutputStreamWriter(getOutputStream(), encoding))
142 : new PrintWriter(new OutputStreamWriter(getOutputStream()));
143 }
144
145 return wrappedWriter;
146 }
147
148 @Override
149 public void flushBuffer() throws IOException {
150 flush();
151 }
152
153 public void flush() throws IOException {
154 wrappedStream.flush();
155
156 if (wrappedWriter != null) {
157 wrappedWriter.flush();
158 }
159 }
160
161 @Override
162 public void reset() {
163 super.reset();
164
165 wrappedWriter = null;
166 status = SC_OK;
167 headers.clear();
168
169 cleanUp();
170 }
171
172 @Override
173 public void resetBuffer() {
174 super.resetBuffer();
175 wrappedWriter = null;
176 cleanUp();
177 }
178
179 public void cleanUp() {
180 if (contentFile != null && contentFile.exists()) {
181 if (!contentFile.delete()) {
182 log.error("Can't delete file: " + contentFile);
183 }
184 }
185 contentFile = null;
186 }
187
188 public int getStatus() {
189 return status;
190 }
191
192 public boolean isError() {
193 return isError;
194 }
195
196 public MultiMap getHeaders() {
197 return headers;
198 }
199
200 public long getLastModified() {
201
202
203 final Collection values = (Collection) headers.get("Last-Modified");
204 if (values == null || values.size() != 1) {
205 throw new IllegalStateException("Can't get Last-Modified header : no or multiple values : " + values);
206 }
207 final Object value = values.iterator().next();
208 if (value instanceof String) {
209 return parseStringDate((String) value);
210 } else if (value instanceof Long) {
211 return ((Long) value).longValue();
212 } else {
213 throw new IllegalStateException("Can't get Last-Modified header : " + value);
214 }
215 }
216
217 private long parseStringDate(String value) {
218 try {
219 return DateUtil.parseDate(value).getTime();
220 } catch (DateParseException e) {
221 throw new IllegalStateException("Could not parse Last-Modified header with value " + value + " : " + e.getMessage());
222 }
223 }
224
225
226
227
228
229
230 public void setResponseExpirationDetectionEnabled() {
231 this.responseExpirationCalculator = new ResponseExpirationCalculator();
232 }
233
234
235
236
237
238
239
240
241 public int getTimeToLiveInSeconds() {
242 return responseExpirationCalculator != null ? responseExpirationCalculator.getMaxAgeInSeconds() : -1;
243 }
244
245 public String getRedirectionLocation() {
246 return redirectionLocation;
247 }
248
249 @Override
250 public void setDateHeader(String name, long date) {
251 replaceHeader(name, Long.valueOf(date));
252 }
253
254 @Override
255 public void addDateHeader(String name, long date) {
256 appendHeader(name, Long.valueOf(date));
257 }
258
259 @Override
260 public void setHeader(String name, String value) {
261 replaceHeader(name, value);
262 }
263
264 @Override
265 public void addHeader(String name, String value) {
266 appendHeader(name, value);
267 }
268
269 @Override
270 public void setIntHeader(String name, int value) {
271 replaceHeader(name, Integer.valueOf(value));
272 }
273
274 @Override
275 public void addIntHeader(String name, int value) {
276 appendHeader(name, Integer.valueOf(value));
277 }
278
279 @Override
280 public boolean containsHeader(String name) {
281 return headers.containsKey(name);
282 }
283
284 private void replaceHeader(String name, Object value) {
285 if (responseExpirationCalculator == null || !responseExpirationCalculator.addHeader(name, value)) {
286 headers.remove(name);
287 headers.put(name, value);
288 }
289 }
290
291 private void appendHeader(String name, Object value) {
292 if (responseExpirationCalculator == null || !responseExpirationCalculator.addHeader(name, value)) {
293 headers.put(name, value);
294 }
295 }
296
297 @Override
298 public void setStatus(int status) {
299 this.status = status;
300 }
301
302 @Override
303 public void setStatus(int status, String string) {
304 this.status = status;
305 }
306
307 @Override
308 public void sendRedirect(String location) throws IOException {
309 this.status = SC_MOVED_TEMPORARILY;
310 this.redirectionLocation = location;
311 }
312
313 @Override
314 public void sendError(int status, String errorMsg) throws IOException {
315 this.errorMsg = errorMsg;
316 this.status = status;
317 this.isError = true;
318 }
319
320 @Override
321 public void sendError(int status) throws IOException {
322 this.status = status;
323 this.isError = true;
324 }
325
326 @Override
327 public void setContentLength(int len) {
328 this.contentLength = len;
329 }
330
331 public int getContentLength() {
332 return (int) (contentLength >= 0 ? contentLength : thresholdingOutputStream.getByteCount());
333 }
334
335 public void replay(HttpServletResponse target) throws IOException {
336 replayHeadersAndStatus(target);
337 replayContent(target, true);
338 }
339
340 public void replayHeadersAndStatus(HttpServletResponse target) throws IOException {
341 if (isError) {
342 if (errorMsg != null) {
343 target.sendError(status, errorMsg);
344 }
345 else {
346 target.sendError(status);
347 }
348 }
349 else if (redirectionLocation != null) {
350 target.sendRedirect(redirectionLocation);
351 }
352 else {
353 target.setStatus(status);
354 }
355
356 target.setStatus(getStatus());
357
358 final Iterator it = headers.keySet().iterator();
359 while (it.hasNext()) {
360 final String header = (String) it.next();
361
362 final Collection values = (Collection) headers.get(header);
363 final Iterator valIt = values.iterator();
364 while (valIt.hasNext()) {
365 final Object val = valIt.next();
366 RequestHeaderUtil.setHeader(target, header, val);
367 }
368 }
369
370
371 target.setContentType(getContentType());
372 target.setCharacterEncoding(getCharacterEncoding());
373 }
374
375 public void replayContent(HttpServletResponse target, boolean setContentLength) throws IOException {
376 if (setContentLength) {
377 target.setContentLength(getContentLength());
378 }
379 if (getContentLength() > 0) {
380 if (isThresholdExceeded()) {
381 FileInputStream in = FileUtils.openInputStream(getContentFile());
382 IOUtils.copy(in, target.getOutputStream());
383 IOUtils.closeQuietly(in);
384 }
385 else {
386 IOUtils.copy(new ByteArrayInputStream(((ByteArrayOutputStream) this.thresholdingOutputStream.getInMemoryBuffer()).toByteArray()), target.getOutputStream());
387 }
388 target.flushBuffer();
389 }
390 }
391
392 protected OutputStream thresholdReached(OutputStream out) throws IOException {
393
394 if (serveIfThresholdReached) {
395 replayHeadersAndStatus(originalResponse);
396 out = originalResponse.getOutputStream();
397 log.debug("Reached threshold for in-memory caching. Will not cache and stream response directly to user.");
398 }
399 else {
400 contentFile = File.createTempFile("cacheStream", null, Path.getTempDirectory());
401 if (contentFile != null) {
402 log.debug("Reached threshold for in-memory caching. Will continue caching in new cache temp file {}", contentFile.getAbsolutePath());
403 contentFile.deleteOnExit();
404 out = new FileOutputStream(contentFile);
405 } else {
406 log.error("Reached threshold for in-memory caching, but unable to create the new cache temp file. Will not cache and stream response directly to user.");
407 replayHeadersAndStatus(originalResponse);
408 out = originalResponse.getOutputStream();
409 }
410 }
411 out.write(getBufferedContent());
412 out.flush();
413 return out;
414 }
415
416 private final class ThresholdingCacheOutputStream extends AbstractThresholdingCacheOutputStream {
417
418 private ThresholdingCacheOutputStream(int threshold) {
419 super(threshold);
420 }
421
422 @Override
423 protected OutputStream getStream() throws IOException {
424 return out;
425 }
426
427 @Override
428 protected void thresholdReached() throws IOException {
429 out = CacheResponseWrapper.this.thresholdReached(out);
430 }
431 }
432 }