001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.juneau.rest.util;
018
019import java.io.*;
020
021import jakarta.servlet.*;
022
023/**
024 * ServletInputStream wrapper around a normal input stream with support for limiting input.
025 *
026 * <h5 class='section'>See Also:</h5><ul>
027 * </ul>
028 */
029public class BoundedServletInputStream extends ServletInputStream {
030
031   private final InputStream is;
032   private final ServletInputStream sis;
033   private long remain;
034
035   /**
036    * Wraps the specified input stream.
037    *
038    * @param is The input stream to wrap.
039    * @param max The maximum number of bytes to read from the stream.
040    */
041   public BoundedServletInputStream(InputStream is, long max) {
042      this.is = is;
043      this.sis = null;
044      this.remain = max;
045   }
046
047   /**
048    * Wraps the specified input stream.
049    *
050    * @param sis The input stream to wrap.
051    * @param max The maximum number of bytes to read from the stream.
052    */
053   public BoundedServletInputStream(ServletInputStream sis, long max) {
054      this.sis = sis;
055      this.is = sis;
056      this.remain = max;
057   }
058
059   /**
060    * Wraps the specified byte array.
061    *
062    * @param b The byte contents of the stream.
063    */
064   public BoundedServletInputStream(byte[] b) {
065      this(new ByteArrayInputStream(b), Long.MAX_VALUE);
066   }
067
068   @Override /* InputStream */
069   public int read() throws IOException {
070      decrement();
071      return is.read();
072   }
073
074   @Override /* InputStream */
075   public int read(byte[] b) throws IOException {
076      return read(b, 0, b.length);
077   }
078
079   @Override /* InputStream */
080   public int read(final byte[] b, final int off, final int len) throws IOException {
081      long numBytes = Math.min(len, remain);
082      int r = is.read(b, off, (int) numBytes);
083      if (r == -1)
084         return -1;
085      decrement(numBytes);
086      return r;
087   }
088
089   @Override /* InputStream */
090   public long skip(final long n) throws IOException {
091      long toSkip = Math.min(n, remain);
092      long r = is.skip(toSkip);
093      decrement(r);
094      return r;
095   }
096
097   @Override /* InputStream */
098   public int available() throws IOException {
099      if (remain <= 0)
100         return 0;
101      return is.available();
102   }
103
104   @Override /* InputStream */
105   public synchronized void reset() throws IOException {
106      is.reset();
107   }
108
109   @Override /* InputStream */
110   public synchronized void mark(int limit) {
111      is.mark(limit);
112   }
113
114   @Override /* InputStream */
115   public boolean markSupported() {
116      return is.markSupported();
117   }
118
119   @Override /* InputStream */
120   public void close() throws IOException {
121      is.close();
122   }
123
124   @Override /* ServletInputStream */
125   public boolean isFinished() {
126      return sis == null ? false : sis.isFinished();
127   }
128
129   @Override /* ServletInputStream */
130   public boolean isReady() {
131      return sis == null ? true : sis.isReady();
132   }
133
134   @Override /* ServletInputStream */
135   public void setReadListener(ReadListener arg0) {
136      if (sis != null)
137         sis.setReadListener(arg0);
138   }
139
140   private void decrement() throws IOException {
141      remain--;
142      if (remain < 0)
143         throw new IOException("Input limit exceeded.  See @Rest(maxInput).");
144   }
145
146   private void decrement(long count) throws IOException {
147      remain -= count;
148      if (remain < 0)
149         throw new IOException("Input limit exceeded.  See @Rest(maxInput).");
150   }
151}