Skip to content
Snippets Groups Projects
Commit 36f445c9 authored by Marius Dumitru Florea's avatar Marius Dumitru Florea
Browse files

XWIKI-22716: The WebSocket context assumes the handshake request and response...

XWIKI-22716: The WebSocket context assumes the handshake request and response objects can be used after the handshake is performed

(cherry picked from commit 07336af7)
parent 5b34bd0d
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@
package org.xwiki.websocket.internal;
import java.net.HttpCookie;
import java.net.URI;
import java.security.Principal;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
......@@ -36,6 +37,9 @@
import javax.servlet.http.HttpSession;
import javax.websocket.server.HandshakeRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.xpn.xwiki.web.XWikiRequest;
import com.xpn.xwiki.web.XWikiServletRequestStub;
......@@ -47,8 +51,16 @@
*/
public class XWikiWebSocketRequestStub extends XWikiServletRequestStub
{
private static final Logger LOGGER = LoggerFactory.getLogger(XWikiWebSocketRequestStub.class);
private final HandshakeRequest request;
private final URI requestURI;
private final String queryString;
private final Principal userPrincipal;
/**
* Creates a new XWiki request that wraps the given WebSocket handshake request.
*
......@@ -59,6 +71,9 @@ public XWikiWebSocketRequestStub(HandshakeRequest request)
super(buildFromHandshakeRequest(request));
this.request = request;
this.requestURI = request.getRequestURI();
this.queryString = request.getQueryString();
this.userPrincipal = request.getUserPrincipal();
}
private static Builder buildFromHandshakeRequest(HandshakeRequest request)
......@@ -67,6 +82,7 @@ private static Builder buildFromHandshakeRequest(HandshakeRequest request)
Optional<String> cookieHeader = headers.getOrDefault("Cookie", Collections.emptyList()).stream().findFirst();
return new Builder().setRequestParameters(adaptParameterMap(request.getParameterMap()))
.setCookies(parseCookies(cookieHeader)).setHeaders(headers)
.setHttpSession((HttpSession) request.getHttpSession())
// The WebSocket API (JSR-356) doesn't expose the client IP address but at least we can avoid a null pointer
// exception.
.setRemoteAddr("");
......@@ -111,7 +127,7 @@ public String getMethod()
@Override
public String getRequestURI()
{
return this.request.getRequestURI().toString();
return this.requestURI.toString();
}
private static Map<String, String[]> adaptParameterMap(Map<String, List<String>> params)
......@@ -123,18 +139,6 @@ private static Map<String, String[]> adaptParameterMap(Map<String, List<String>>
return parameters;
}
@Override
public HttpSession getSession()
{
return getSession(true);
}
@Override
public HttpSession getSession(boolean create)
{
return (HttpSession) this.request.getHttpSession();
}
@Override
public String getServletPath()
{
......@@ -144,30 +148,37 @@ public String getServletPath()
@Override
public String getPathInfo()
{
return this.request.getRequestURI().getPath();
return this.requestURI.getPath();
}
@Override
public String getScheme()
{
return this.request.getRequestURI().getScheme();
return this.requestURI.getScheme();
}
@Override
public String getQueryString()
{
return this.request.getQueryString();
return this.queryString;
}
@Override
public Principal getUserPrincipal()
{
return this.request.getUserPrincipal();
return this.userPrincipal;
}
@Override
public boolean isUserInRole(String role)
{
return this.request.isUserInRole(role);
try {
return this.request.isUserInRole(role);
} catch (Exception e) {
LOGGER.debug("Failed to determine if the currently authenticated user has the specified role. "
+ "This can happen if this method is called outside the WebSocket handshake request, "
+ "i.e. from a WebSocket end-point.", e);
return false;
}
}
}
......@@ -25,10 +25,11 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
......@@ -36,6 +37,8 @@
import javax.websocket.HandshakeResponse;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.xpn.xwiki.web.XWikiRequest;
import com.xpn.xwiki.web.XWikiResponse;
......@@ -49,6 +52,8 @@
*/
public class XWikiWebSocketResponseStub extends XWikiServletResponseStub
{
private static final Logger LOGGER = LoggerFactory.getLogger(XWikiWebSocketResponseStub.class);
private final HandshakeResponse response;
/**
......@@ -64,58 +69,67 @@ public XWikiWebSocketResponseStub(HandshakeResponse response)
@Override
public void addHeader(String name, String value)
{
List<String> values = getHeaderValues(name);
if (values == null) {
values = new ArrayList<>();
this.response.getHeaders().put(name, values);
}
List<String> values = getHeaderValues(name).orElseGet(() -> {
List<String> emptyValues = new ArrayList<>();
getHeaders().put(name, emptyValues);
return emptyValues;
});
values.add(value);
}
@Override
public boolean containsHeader(String name)
{
List<String> values = getHeaderValues(name);
return values != null && !values.isEmpty();
return getHeaderValues(name).map(values -> !values.isEmpty()).orElse(false);
}
@Override
public String getHeader(String name)
{
List<String> values = getHeaderValues(name);
return values != null && !values.isEmpty() ? values.get(0) : null;
return getHeaderValues(name).map(values -> values.isEmpty() ? null : values.get(0)).orElse(null);
}
private List<String> getHeaderValues(String name)
private Optional<List<String>> getHeaderValues(String name)
{
for (Map.Entry<String, List<String>> entry : this.response.getHeaders().entrySet()) {
for (Map.Entry<String, List<String>> entry : getHeaders().entrySet()) {
if (StringUtils.equalsIgnoreCase(name, entry.getKey())) {
return entry.getValue();
return Optional.of(entry.getValue());
}
}
return null;
return Optional.empty();
}
@Override
public Collection<String> getHeaders(String name)
{
List<String> values = getHeaderValues(name);
return values != null ? new ArrayList<>(values) : Collections.emptyList();
return getHeaderValues(name).map(ArrayList::new).orElseGet(ArrayList::new);
}
@Override
public Collection<String> getHeaderNames()
{
return new LinkedHashSet<>(this.response.getHeaders().keySet());
return new LinkedHashSet<>(getHeaders().keySet());
}
private Map<String, List<String>> getHeaders()
{
try {
return this.response.getHeaders();
} catch (Exception e) {
LOGGER.debug("Failed to retrieve the WebSocket handshake response headers. "
+ "This can happen if the HandshakeResponse object is used after the handshake is performed, "
+ "e.g. in the WebSocket end-point.", e);
return new HashMap<>();
}
}
@Override
public void setHeader(String name, String value)
{
Set<String> namesToRemove = this.response.getHeaders().keySet().stream()
Set<String> namesToRemove = getHeaders().keySet().stream()
.filter(headerName -> StringUtils.equalsIgnoreCase(name, headerName)).collect(Collectors.toSet());
this.response.getHeaders().keySet().removeAll(namesToRemove);
this.response.getHeaders().put(name, new ArrayList<>(Arrays.asList(value)));
getHeaders().keySet().removeAll(namesToRemove);
getHeaders().put(name, new ArrayList<>(Arrays.asList(value)));
}
@Override
......
......@@ -37,6 +37,7 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
......@@ -102,4 +103,15 @@ void verifyStub() throws Exception
assertFalse(stub.isUserInRole("developer"));
assertTrue(stub.isUserInRole("tester"));
}
@Test
void staleRequest()
{
HandshakeRequest handshakeRequest = mock(HandshakeRequest.class);
when(handshakeRequest.isUserInRole(anyString())).thenThrow(new RuntimeException("Stale request"));
XWikiWebSocketRequestStub stub = new XWikiWebSocketRequestStub(handshakeRequest);
assertFalse(stub.isUserInRole("admin"));
}
}
......@@ -35,6 +35,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
......@@ -47,7 +48,7 @@
class XWikiWebSocketResponseStubTest
{
@Test
void verifyStub() throws Exception
void verifyStub()
{
Map<String, List<String>> headers = new LinkedHashMap<>();
......@@ -90,4 +91,15 @@ void verifyStub() throws Exception
assertTrue(stub.containsHeader("dATe"));
assertFalse(stub.containsHeader("Age"));
}
@Test
void staleResponse()
{
HandshakeResponse handshakeResponse = mock(HandshakeResponse.class);
when(handshakeResponse.getHeaders()).thenThrow(new RuntimeException("Stale response"));
XWikiWebSocketResponseStub stub = new XWikiWebSocketResponseStub(handshakeResponse);
assertNull(stub.getHeader("foo"));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment