package jmri.util.web;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.Locale;
import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import jmri.web.servlet.ServletUtil;
/**
* Test utility that simulates a servlet request/response exchange using
* Mockito-backed {@link HttpServletRequest} and {@link HttpServletResponse}.
*
* <p>
* Designed for unit testing servlet logic without a servlet container.
* Captures response status, headers, redirects, content type, and body.
* This is not a full servlet container simulation; only commonly used
* servlet behaviours are implemented.
* Spring Framework moved from javax.servlet to jakarta.servlet in version 6,
* hence unable to use their Mock test classes from 5.3.39.
*/
public class MockServletExchange {
public static final String DELETE = "DELETE";
public static final String GET = "GET";
public static final String POST = "POST";
public static final String PUT = "PUT";
private final HttpServletRequest request = mock(HttpServletRequest.class);
private final HttpServletResponse response = mock(HttpServletResponse.class);
private final HttpSession session = mock(HttpSession.class);
private final ServletConfig config = mock(ServletConfig.class);
private final ServletContext context = mock(ServletContext.class);
private final ByteArrayOutputStream responseOutputStream = new ByteArrayOutputStream();
private final PrintWriter responseWriter =
new PrintWriter(new OutputStreamWriter(responseOutputStream, StandardCharsets.UTF_8), true);
private final Map> headers = new HashMap<>();
private final Map parameterMap = new HashMap<>();
private final Map attributes = new HashMap<>();
private int status = HttpServletResponse.SC_OK;
private String responseContentType;
private String redirectedUrl;
/**
* Creates a new mocked servlet exchange.
*
* @param method the HTTP method (e.g. "GET", "POST").
* @param requestUri the request URI.
*/
public MockServletExchange(String method, String requestUri) {
when(config.getServletContext()).thenReturn(context);
when(request.getMethod()).thenReturn(method);
when(request.getRequestURI()).thenReturn(requestUri);
when(request.getLocale()).thenReturn(Locale.ENGLISH);
when(request.getSession()).thenReturn(session);
when(request.getCharacterEncoding()).thenReturn(ServletUtil.UTF8);
when(request.getParameterMap())
.thenAnswer(inv -> Collections.unmodifiableMap(parameterMap));
when(request.getParameter(anyString())).thenAnswer(inv -> {
String[] values = parameterMap.get(inv.getArgument(0, String.class));
return (values != null && values.length > 0) ? values[0] : null;
});
when(request.getAttribute(anyString()))
.thenAnswer(inv -> attributes.get(inv.getArgument(0, String.class)));
assertDoesNotThrow( () ->
when(response.getWriter()).thenReturn(responseWriter));
assertDoesNotThrow( () ->
when(response.getOutputStream()).thenReturn(new ServletOutputStream() {
@Override
public void write(int b) {
responseOutputStream.write(b);
}
@Override
public boolean isReady() {
return true;
}
@Override
public void setWriteListener(javax.servlet.WriteListener l) {
}
}));
// Redirect capture
assertDoesNotThrow( () ->
doAnswer(inv -> {
String url = inv.getArgument(0, String.class);
if (url == null) {
throw new IllegalArgumentException("Redirect URL must not be null");
}
this.redirectedUrl = url;
this.status = HttpServletResponse.SC_MOVED_TEMPORARILY;
return null;
}).when(response).sendRedirect(any())); // any passes null instances through
// case-insensitive check
doAnswer(inv -> {
String name = inv.getArgument(0, String.class);
String value = inv.getArgument(1, String.class);
// Standardise key to lowercase for storage
headers.put(name.toLowerCase(), new ArrayList<>(List.of(value)));
return null;
}).when(response).setHeader(anyString(), anyString());
// Fix getHeader to return from the lowercase map
when(response.getHeader(anyString())).thenAnswer(inv -> {
String name = inv.getArgument(0, String.class);
List values = headers.get(name.toLowerCase());
return (values != null && !values.isEmpty()) ? values.get(0) : null;
});
// addHeader appends
doAnswer(inv -> {
String name = inv.getArgument(0, String.class);
String value = inv.getArgument(1, String.class);
headers.computeIfAbsent(name, k -> new ArrayList<>()).add(value);
return null;
}).when(response).addHeader(anyString(), anyString());
doAnswer(inv -> {
String name = inv.getArgument(0, String.class);
long timestamp = inv.getArgument(1, Long.class);
// Use lowercase keys to remain consistent with your setHeader mock
headers.put(name.toLowerCase(), new ArrayList<>(List.of(getRfc7232formatHttpDate(timestamp))));
return null;
}).when(response).setDateHeader(anyString(), anyLong());
// addDateHeader appends
doAnswer(inv -> {
String name = inv.getArgument(0, String.class);
long timestamp = inv.getArgument(1, Long.class);
headers.computeIfAbsent(name.toLowerCase(), k -> new ArrayList<>())
.add(getRfc7232formatHttpDate(timestamp));
return null;
}).when(response).addDateHeader(anyString(), anyLong());
when(response.getHeaderNames())
.thenAnswer(inv -> Collections.unmodifiableSet(headers.keySet()));
when(response.getHeaders(anyString()))
.thenAnswer(inv -> {
List values = headers.get(inv.getArgument(0, String.class));
return values == null
? Collections.emptyList()
: Collections.unmodifiableList(values);
});
doAnswer(inv -> {
this.status = inv.getArgument(0, Integer.class);
return null;
}).when(response).setStatus(anyInt());
doAnswer(inv -> {
this.responseContentType = inv.getArgument(0, String.class);
return null;
}).when(response).setContentType(anyString());
doAnswer(inv -> {
this.attributes.put(inv.getArgument(0), inv.getArgument(1));
return null;
}).when(request).setAttribute(anyString(), any());
withBody(""); // default empty
}
public final MockServletExchange withBody(String body) {
assertDoesNotThrow( () ->
when(request.getReader())
.thenReturn(new BufferedReader(
new StringReader(body != null ? body : ""))));
return this;
}
public MockServletExchange withParameter(String key, String value) {
parameterMap.put(key, new String[]{value});
return this;
}
public MockServletExchange withAttribute(String key, Object value) {
attributes.put(key, value);
return this;
}
public MockServletExchange withRequestContentType(String type) {
when(request.getContentType()).thenReturn(type);
return this;
}
public MockServletExchange withContextPath(String path) {
when(request.getContextPath()).thenReturn(path);
return this;
}
public MockServletExchange withPathInfo(String info) {
when(request.getPathInfo()).thenReturn(info);
return this;
}
public HttpServletRequest getRequest() {
return request;
}
public HttpServletResponse getResponse() {
return response;
}
public HttpSession getSession() {
return session;
}
public ServletConfig getConfig() {
return config;
}
public int getResponseStatus() {
return status;
}
public String getResponseContentType() {
return responseContentType;
}
public String getRedirectedUrl() {
return redirectedUrl;
}
public String getResponseContentAsString() {
responseWriter.flush();
return responseOutputStream.toString(StandardCharsets.UTF_8);
}
public static String getRfc7232formatHttpDate(long timestamp) {
return DateTimeFormatter.RFC_1123_DATE_TIME
.withZone(ZoneOffset.UTC)
.format(Instant.ofEpochMilli(timestamp));
}
}