001// ***************************************************************************************************************************
002// * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements.  See the NOTICE file *
003// * distributed with this work for additional information regarding copyright ownership.  The ASF licenses this file        *
004// * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance            *
005// * with the License.  You may obtain a copy of the License at                                                              *
006// *                                                                                                                         *
007// *  http://www.apache.org/licenses/LICENSE-2.0                                                                             *
008// *                                                                                                                         *
009// * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an  *
010// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the License for the        *
011// * specific language governing permissions and limitations under the License.                                              *
012// ***************************************************************************************************************************
013package org.apache.juneau.rest.util;
014
015import java.io.*;
016
017import javax.servlet.*;
018
019/**
020 * ServletInputStream wrapper around a normal input stream with support for limiting input.
021 */
022public final class BoundedServletInputStream extends ServletInputStream {
023
024   private final InputStream is;
025   private final ServletInputStream sis;
026   private long remain;
027
028   /**
029    * Wraps the specified input stream.
030    *
031    * @param is The input stream to wrap.
032    * @param max The maximum number of bytes to read from the stream.
033    */
034   public BoundedServletInputStream(InputStream is, long max) {
035      this.is = is;
036      this.sis = null;
037      this.remain = max;
038   }
039
040   /**
041    * Wraps the specified input stream.
042    *
043    * @param sis The input stream to wrap.
044    * @param max The maximum number of bytes to read from the stream.
045    */
046   public BoundedServletInputStream(ServletInputStream sis, long max) {
047      this.sis = sis;
048      this.is = sis;
049      this.remain = max;
050   }
051
052   /**
053    * Wraps the specified byte array.
054    *
055    * @param b The byte contents of the stream.
056    */
057   public BoundedServletInputStream(byte[] b) {
058      this(new ByteArrayInputStream(b), Long.MAX_VALUE);
059   }
060
061   @Override /* InputStream */
062   public final int read() throws IOException {
063      decrement();
064      return is.read();
065   }
066
067   @Override /* InputStream */
068   public int read(byte[] b) throws IOException {
069      return read(b, 0, b.length);
070   }
071
072   @Override /* InputStream */
073   public int read(final byte[] b, final int off, final int len) throws IOException {
074      long numBytes = Math.min(len, remain);
075      int r = is.read(b, off, (int) numBytes);
076      if (r == -1)
077         return -1;
078      decrement(numBytes);
079      return r;
080   }
081
082   @Override /* InputStream */
083   public long skip(final long n) throws IOException {
084      long toSkip = Math.min(n, remain);
085      long r = is.skip(toSkip);
086      decrement(r);
087      return r;
088   }
089
090   @Override /* InputStream */
091   public int available() throws IOException {
092      if (remain <= 0)
093         return 0;
094      return is.available();
095   }
096
097   @Override /* InputStream */
098   public synchronized void reset() throws IOException {
099      is.reset();
100   }
101
102   @Override /* InputStream */
103   public synchronized void mark(int limit) {
104      is.mark(limit);
105   }
106
107   @Override /* InputStream */
108   public boolean markSupported() {
109      return is.markSupported();
110   }
111
112   @Override /* InputStream */
113   public final void close() throws IOException {
114      is.close();
115   }
116
117   @Override /* ServletInputStream */
118   public boolean isFinished() {
119      return sis == null ? false : sis.isFinished();
120   }
121
122   @Override /* ServletInputStream */
123   public boolean isReady() {
124      return sis == null ? true : sis.isReady();
125   }
126
127   @Override /* ServletInputStream */
128   public void setReadListener(ReadListener arg0) {
129      if (sis != null)
130         sis.setReadListener(arg0);
131   }
132
133   private void decrement() throws IOException {
134      remain--;
135      if (remain < 0)
136         throw new IOException("Input limit exceeded.  See @RestResource(maxInput).");
137   }
138
139   private void decrement(long count) throws IOException {
140      remain -= count;
141      if (remain < 0)
142         throw new IOException("Input limit exceeded.  See @RestResource(maxInput).");
143   }
144}