001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.juneau.rest.util;
018
019import static org.apache.juneau.commons.utils.ThrowableUtils.*;
020import static org.apache.juneau.commons.utils.Utils.*;
021
022import java.io.*;
023
024import jakarta.servlet.*;
025
026/**
027 * ServletInputStream wrapper around a normal input stream with support for limiting input.
028 *
029 */
030public class BoundedServletInputStream extends ServletInputStream {
031
032   private final InputStream is;
033   private final ServletInputStream sis;
034   private long remain;
035
036   /**
037    * Wraps the specified byte array.
038    *
039    * @param b The byte contents of the stream.
040    */
041   public BoundedServletInputStream(byte[] b) {
042      this(new ByteArrayInputStream(b), Long.MAX_VALUE);
043   }
044
045   /**
046    * Wraps the specified input stream.
047    *
048    * @param is The input stream to wrap.
049    * @param max The maximum number of bytes to read from the stream.
050    */
051   public BoundedServletInputStream(InputStream is, long max) {
052      this.is = is;
053      this.sis = null;
054      this.remain = max;
055   }
056
057   /**
058    * Wraps the specified input stream.
059    *
060    * @param sis The input stream to wrap.
061    * @param max The maximum number of bytes to read from the stream.
062    */
063   public BoundedServletInputStream(ServletInputStream sis, long max) {
064      this.sis = sis;
065      this.is = sis;
066      this.remain = max;
067   }
068
069   @Override /* Overridden from InputStream */
070   public int available() throws IOException {
071      if (remain <= 0)
072         return 0;
073      return is.available();
074   }
075
076   @Override /* Overridden from InputStream */
077   public void close() throws IOException {
078      is.close();
079   }
080
081   @Override /* Overridden from ServletInputStream */
082   public boolean isFinished() { return sis == null ? false : sis.isFinished(); }
083
084   @Override /* Overridden from ServletInputStream */
085   public boolean isReady() { return sis == null ? true : sis.isReady(); }
086
087   @Override /* Overridden from InputStream */
088   public synchronized void mark(int limit) {
089      is.mark(limit);
090   }
091
092   @Override /* Overridden from InputStream */
093   public boolean markSupported() {
094      return is.markSupported();
095   }
096
097   @Override /* Overridden from InputStream */
098   public int read() throws IOException {
099      decrement();
100      return is.read();
101   }
102
103   @Override /* Overridden from InputStream */
104   public int read(byte[] b) throws IOException {
105      return read(b, 0, b.length);
106   }
107
108   @Override /* Overridden from InputStream */
109   public int read(byte[] b, int off, int len) throws IOException {
110      long numBytes = Math.min(len, remain);
111      int r = is.read(b, off, (int)numBytes);
112      if (r == -1)
113         return -1;
114      decrement(numBytes);
115      return r;
116   }
117
118   @Override /* Overridden from InputStream */
119   public synchronized void reset() throws IOException {
120      is.reset();
121   }
122
123   @Override /* Overridden from ServletInputStream */
124   public void setReadListener(ReadListener arg0) {
125      if (nn(sis))
126         sis.setReadListener(arg0);
127   }
128
129   @Override /* Overridden from InputStream */
130   public long skip(long n) throws IOException {
131      long toSkip = Math.min(n, remain);
132      long r = is.skip(toSkip);
133      decrement(r);
134      return r;
135   }
136
137   private void decrement() throws IOException {
138      remain--;
139      if (remain < 0)
140         throw ioex("Input limit exceeded.  See @Rest(maxInput).");
141   }
142
143   private void decrement(long count) throws IOException {
144      remain -= count;
145      if (remain < 0)
146         throw ioex("Input limit exceeded.  See @Rest(maxInput).");
147   }
148}