/* * JBoss, Home of Professional Open Source. * Copyright 2012 Red Hat, Inc., and individual contributors * as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.undertow.servlet.spec; import io.undertow.Version; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.resource.Resource; import io.undertow.server.session.PathParameterSessionConfig; import io.undertow.server.session.Session; import io.undertow.server.session.SessionConfig; import io.undertow.server.session.SessionManager; import io.undertow.server.session.SslSessionConfig; import io.undertow.servlet.UndertowServletLogger; import io.undertow.servlet.UndertowServletMessages; import io.undertow.servlet.api.Deployment; import io.undertow.servlet.api.DeploymentInfo; import io.undertow.servlet.api.DeploymentManager; import io.undertow.servlet.api.FilterInfo; import io.undertow.servlet.api.HttpMethodSecurityInfo; import io.undertow.servlet.api.InstanceFactory; import io.undertow.servlet.api.ListenerInfo; import io.undertow.servlet.api.SecurityInfo; import io.undertow.servlet.api.ServletContainer; import io.undertow.servlet.api.ServletInfo; import io.undertow.servlet.api.ServletSecurityInfo; import io.undertow.servlet.api.SessionConfigWrapper; import io.undertow.servlet.api.TransportGuaranteeType; import io.undertow.servlet.core.ApplicationListeners; import io.undertow.servlet.core.ManagedListener; import io.undertow.servlet.handlers.ServletChain; import io.undertow.servlet.util.EmptyEnumeration; import io.undertow.servlet.util.ImmediateInstanceFactory; import io.undertow.servlet.util.IteratorEnumeration; import io.undertow.util.AttachmentKey; import javax.annotation.security.DeclareRoles; import javax.annotation.security.RunAs; import javax.servlet.Filter; import javax.servlet.FilterRegistration; import javax.servlet.MultipartConfigElement; import javax.servlet.RequestDispatcher; import javax.servlet.Servlet; import javax.servlet.ServletContext; import javax.servlet.ServletContextListener; import javax.servlet.ServletException; import javax.servlet.ServletRegistration; import javax.servlet.SessionTrackingMode; import javax.servlet.annotation.HttpMethodConstraint; import javax.servlet.annotation.MultipartConfig; import javax.servlet.annotation.ServletSecurity; import javax.servlet.descriptor.JspConfigDescriptor; import java.io.BufferedInputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.net.MalformedURLException; import java.net.URL; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Arrays; import java.util.Collections; import java.util.Enumeration; import java.util.EventListener; import java.util.HashMap; import java.util.HashSet; import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import static io.undertow.servlet.core.ApplicationListeners.ListenerState.NO_LISTENER; import static io.undertow.servlet.core.ApplicationListeners.ListenerState.PROGRAMATIC_LISTENER; /** * @author Stuart Douglas */ public class ServletContextImpl implements ServletContext { private final ServletContainer servletContainer; private final Deployment deployment; private DeploymentInfo deploymentInfo; private final ConcurrentMap attributes; private final SessionCookieConfigImpl sessionCookieConfig; private final AttachmentKey sessionAttachmentKey = AttachmentKey.create(HttpSessionImpl.class); private volatile Set sessionTrackingModes = new HashSet(Arrays.asList(new SessionTrackingMode[]{SessionTrackingMode.COOKIE, SessionTrackingMode.URL})); private volatile Set defaultSessionTrackingModes = new HashSet(Arrays.asList(new SessionTrackingMode[]{SessionTrackingMode.COOKIE, SessionTrackingMode.URL})); private volatile SessionConfig sessionConfig; private volatile boolean initialized = false; public ServletContextImpl(final ServletContainer servletContainer, final Deployment deployment) { this.servletContainer = servletContainer; this.deployment = deployment; this.deploymentInfo = deployment.getDeploymentInfo(); sessionCookieConfig = new SessionCookieConfigImpl(this); sessionCookieConfig.setPath(deploymentInfo.getContextPath()); if (deploymentInfo.getServletContextAttributeBackingMap() == null) { this.attributes = new ConcurrentHashMap(); } else { this.attributes = deploymentInfo.getServletContextAttributeBackingMap(); } attributes.putAll(deployment.getDeploymentInfo().getServletContextAttributes()); } public void initDone() { initialized = true; Set trackingMethods = sessionTrackingModes; SessionConfig sessionConfig = sessionCookieConfig; if (trackingMethods != null && !trackingMethods.isEmpty()) { if (sessionTrackingModes.contains(SessionTrackingMode.SSL)) { sessionConfig = new SslSessionConfig(deployment.getSessionManager()); } else { if (sessionTrackingModes.contains(SessionTrackingMode.COOKIE) && sessionTrackingModes.contains(SessionTrackingMode.URL)) { sessionCookieConfig.setFallback(new PathParameterSessionConfig(sessionCookieConfig.getName().toLowerCase(Locale.ENGLISH))); } else if (sessionTrackingModes.contains(SessionTrackingMode.URL)) { sessionConfig = new PathParameterSessionConfig(sessionCookieConfig.getName().toLowerCase(Locale.ENGLISH)); } } } SessionConfigWrapper wrapper = deploymentInfo.getSessionConfigWrapper(); if (wrapper != null) { sessionConfig = wrapper.wrap(sessionConfig, deployment); } this.sessionConfig = sessionConfig; } @Override public String getContextPath() { return deploymentInfo.getContextPath(); } @Override public ServletContext getContext(final String uripath) { DeploymentManager deploymentByPath = servletContainer.getDeploymentByPath(uripath); if (deploymentByPath == null) { return null; } return deploymentByPath.getDeployment().getServletContext(); } @Override public int getMajorVersion() { return 3; } @Override public int getMinorVersion() { return 1; } @Override public int getEffectiveMajorVersion() { return deploymentInfo.getMajorVersion(); } @Override public int getEffectiveMinorVersion() { return deploymentInfo.getMinorVersion(); } @Override public String getMimeType(final String file) { int pos = file.lastIndexOf('.'); if (pos == -1) { return deployment.getMimeExtensionMappings().get(file); } return deployment.getMimeExtensionMappings().get(file.substring(pos + 1)); } @Override public Set getResourcePaths(final String path) { final Resource resource; try { resource = deploymentInfo.getResourceManager().getResource(path); } catch (IOException e) { return null; } if (resource == null || !resource.isDirectory()) { return null; } final Set resources = new HashSet(); for (Resource res : resource.list()) { File file = res.getFile(); if (file != null) { File base = res.getResourceManagerRoot(); if (base == null) { resources.add(file.getPath()); //not much else we can do here } else { String filePath = file.getAbsolutePath().substring(base.getAbsolutePath().length()); filePath = filePath.replace('\\', '/'); //for windows systems if (file.isDirectory()) { filePath = filePath + "/"; } resources.add(filePath); } } } return resources; } @Override public URL getResource(final String path) throws MalformedURLException { if (!path.startsWith("/")) { throw UndertowServletMessages.MESSAGES.pathMustStartWithSlash(path); } Resource resource = null; try { resource = deploymentInfo.getResourceManager().getResource(path); } catch (IOException e) { return null; } if (resource == null) { return null; } return resource.getUrl(); } @Override public InputStream getResourceAsStream(final String path) { Resource resource = null; try { resource = deploymentInfo.getResourceManager().getResource(path); } catch (IOException e) { return null; } if (resource == null) { return null; } try { if (resource.getFile() != null) { return new BufferedInputStream(new FileInputStream(resource.getFile())); } else { return new BufferedInputStream(resource.getUrl().openStream()); } } catch (FileNotFoundException e) { //should never happen, as the resource loader should return null in this case return null; } catch (IOException e) { return null; } } @Override public RequestDispatcher getRequestDispatcher(final String path) { return new RequestDispatcherImpl(path, this); } @Override public RequestDispatcher getNamedDispatcher(final String name) { ServletChain chain = deployment.getServletPaths().getServletHandlerByName(name); if (chain != null) { return new RequestDispatcherImpl(chain, this); } else { return null; } } @Override public Servlet getServlet(final String name) throws ServletException { return deployment.getServletPaths().getServletHandlerByName(name).getManagedServlet().getServlet().getInstance(); } @Override public Enumeration getServlets() { return EmptyEnumeration.instance(); } @Override public Enumeration getServletNames() { return EmptyEnumeration.instance(); } @Override public void log(final String msg) { UndertowServletLogger.ROOT_LOGGER.info(msg); } @Override public void log(final Exception exception, final String msg) { UndertowServletLogger.ROOT_LOGGER.error(msg, exception); } @Override public void log(final String message, final Throwable throwable) { UndertowServletLogger.ROOT_LOGGER.error(message, throwable); } @Override public String getRealPath(final String path) { if (path == null) { return null; } Resource resource = null; try { resource = deploymentInfo.getResourceManager().getResource(path); } catch (IOException e) { return null; } if (resource == null) { return null; } File file = resource.getFile(); if (file == null) { return null; } return file.getAbsolutePath(); } @Override public String getServerInfo() { return deploymentInfo.getServerName() + " - " + Version.getVersionString(); } @Override public String getInitParameter(final String name) { if (name == null) { throw UndertowServletMessages.MESSAGES.nullName(); } return deploymentInfo.getInitParameters().get(name); } @Override public Enumeration getInitParameterNames() { return new IteratorEnumeration(deploymentInfo.getInitParameters().keySet().iterator()); } @Override public boolean setInitParameter(final String name, final String value) { if (deploymentInfo.getInitParameters().containsKey(name)) { return false; } deploymentInfo.addInitParameter(name, value); return true; } @Override public Object getAttribute(final String name) { return attributes.get(name); } @Override public Enumeration getAttributeNames() { return new IteratorEnumeration(attributes.keySet().iterator()); } @Override public void setAttribute(final String name, final Object object) { if (object == null) { Object existing = attributes.remove(name); if (deployment.getApplicationListeners() != null) { if (existing != null) { deployment.getApplicationListeners().servletContextAttributeRemoved(name, existing); } } } else { Object existing = attributes.put(name, object); if (deployment.getApplicationListeners() != null) { if (existing != null) { deployment.getApplicationListeners().servletContextAttributeReplaced(name, existing); } else { deployment.getApplicationListeners().servletContextAttributeAdded(name, object); } } } } @Override public void removeAttribute(final String name) { Object exiting = attributes.remove(name); deployment.getApplicationListeners().servletContextAttributeRemoved(name, exiting); } @Override public String getServletContextName() { return deploymentInfo.getDisplayName(); } @Override public ServletRegistration.Dynamic addServlet(final String servletName, final String className) { ensureNotProgramaticListener(); ensureNotInitialized(); try { if (deploymentInfo.getServlets().containsKey(servletName)) { return null; } ServletInfo servlet = new ServletInfo(servletName, (Class) deploymentInfo.getClassLoader().loadClass(className)); readServletAnnotations(servlet); deploymentInfo.addServlet(servlet); deployment.getServlets().addServlet(servlet); return new ServletRegistrationImpl(servlet, deployment); } catch (ClassNotFoundException e) { throw UndertowServletMessages.MESSAGES.cannotLoadClass(className, e); } } @Override public ServletRegistration.Dynamic addServlet(final String servletName, final Servlet servlet) { ensureNotProgramaticListener(); ensureNotInitialized(); if (deploymentInfo.getServlets().containsKey(servletName)) { return null; } ServletInfo s = new ServletInfo(servletName, servlet.getClass(), new ImmediateInstanceFactory(servlet)); readServletAnnotations(s); deploymentInfo.addServlet(s); deployment.getServlets().addServlet(s); return new ServletRegistrationImpl(s, deployment); } @Override public ServletRegistration.Dynamic addServlet(final String servletName, final Class servletClass) { ensureNotProgramaticListener(); ensureNotInitialized(); if (deploymentInfo.getServlets().containsKey(servletName)) { return null; } ServletInfo servlet = new ServletInfo(servletName, servletClass); readServletAnnotations(servlet); deploymentInfo.addServlet(servlet); deployment.getServlets().addServlet(servlet); return new ServletRegistrationImpl(servlet, deployment); } @Override public T createServlet(final Class clazz) throws ServletException { ensureNotProgramaticListener(); try { return deploymentInfo.getClassIntrospecter().createInstanceFactory(clazz).createInstance().getInstance(); } catch (Exception e) { throw UndertowServletMessages.MESSAGES.couldNotInstantiateComponent(clazz.getName(), e); } } @Override public ServletRegistration getServletRegistration(final String servletName) { ensureNotProgramaticListener(); final ServletInfo servlet = deploymentInfo.getServlets().get(servletName); if (servlet == null) { return null; } return new ServletRegistrationImpl(servlet, deployment); } @Override public Map getServletRegistrations() { ensureNotProgramaticListener(); final Map ret = new HashMap(); for (Map.Entry entry : deploymentInfo.getServlets().entrySet()) { ret.put(entry.getKey(), new ServletRegistrationImpl(entry.getValue(), deployment)); } return ret; } @Override public FilterRegistration.Dynamic addFilter(final String filterName, final String className) { ensureNotProgramaticListener(); ensureNotInitialized(); if (deploymentInfo.getFilters().containsKey(filterName)) { return null; } try { FilterInfo filter = new FilterInfo(filterName, (Class) deploymentInfo.getClassLoader().loadClass(className)); deploymentInfo.addFilter(filter); deployment.getFilters().addFilter(filter); return new FilterRegistrationImpl(filter, deployment); } catch (ClassNotFoundException e) { throw UndertowServletMessages.MESSAGES.cannotLoadClass(className, e); } } @Override public FilterRegistration.Dynamic addFilter(final String filterName, final Filter filter) { ensureNotProgramaticListener(); ensureNotInitialized(); if (deploymentInfo.getFilters().containsKey(filterName)) { return null; } FilterInfo f = new FilterInfo(filterName, filter.getClass(), new ImmediateInstanceFactory(filter)); deploymentInfo.addFilter(f); deployment.getFilters().addFilter(f); return new FilterRegistrationImpl(f, deployment); } @Override public FilterRegistration.Dynamic addFilter(final String filterName, final Class filterClass) { ensureNotProgramaticListener(); ensureNotInitialized(); if (deploymentInfo.getFilters().containsKey(filterName)) { return null; } FilterInfo filter = new FilterInfo(filterName, filterClass); deploymentInfo.addFilter(filter); deployment.getFilters().addFilter(filter); return new FilterRegistrationImpl(filter, deployment); } @Override public T createFilter(final Class clazz) throws ServletException { ensureNotProgramaticListener(); try { return deploymentInfo.getClassIntrospecter().createInstanceFactory(clazz).createInstance().getInstance(); } catch (Exception e) { throw UndertowServletMessages.MESSAGES.couldNotInstantiateComponent(clazz.getName(), e); } } @Override public FilterRegistration getFilterRegistration(final String filterName) { ensureNotProgramaticListener(); final FilterInfo filterInfo = deploymentInfo.getFilters().get(filterName); if (filterInfo == null) { return null; } return new FilterRegistrationImpl(filterInfo, deployment); } @Override public Map getFilterRegistrations() { ensureNotProgramaticListener(); final Map ret = new HashMap(); for (Map.Entry entry : deploymentInfo.getFilters().entrySet()) { ret.put(entry.getKey(), new FilterRegistrationImpl(entry.getValue(), deployment)); } return ret; } @Override public SessionCookieConfigImpl getSessionCookieConfig() { ensureNotProgramaticListener(); return sessionCookieConfig; } @Override public void setSessionTrackingModes(final Set sessionTrackingModes) { ensureNotProgramaticListener(); ensureNotInitialized(); if (sessionTrackingModes.size() > 1 && sessionTrackingModes.contains(SessionTrackingMode.SSL)) { throw UndertowServletMessages.MESSAGES.sslCannotBeCombinedWithAnyOtherMethod(); } this.sessionTrackingModes = new HashSet(sessionTrackingModes); //TODO: actually make this work } @Override public Set getDefaultSessionTrackingModes() { ensureNotProgramaticListener(); return defaultSessionTrackingModes; } @Override public Set getEffectiveSessionTrackingModes() { ensureNotProgramaticListener(); return Collections.unmodifiableSet(sessionTrackingModes); } @Override public void addListener(final String className) { try { Class clazz = (Class) deploymentInfo.getClassLoader().loadClass(className); addListener(clazz); } catch (ClassNotFoundException e) { throw new IllegalArgumentException(e); } } @Override public void addListener(final T t) { ensureNotInitialized(); ensureNotProgramaticListener(); if (ApplicationListeners.listenerState() != NO_LISTENER && ServletContextListener.class.isAssignableFrom(t.getClass())) { throw UndertowServletMessages.MESSAGES.cannotAddServletContextListener(); } ListenerInfo listener = new ListenerInfo(t.getClass(), new ImmediateInstanceFactory(t)); deploymentInfo.addListener(listener); deployment.getApplicationListeners().addListener(new ManagedListener(listener, true)); } @Override public void addListener(final Class listenerClass) { ensureNotInitialized(); ensureNotProgramaticListener(); if (ApplicationListeners.listenerState() != NO_LISTENER && ServletContextListener.class.isAssignableFrom(listenerClass)) { throw UndertowServletMessages.MESSAGES.cannotAddServletContextListener(); } InstanceFactory factory = null; try { factory = deploymentInfo.getClassIntrospecter().createInstanceFactory(listenerClass); } catch (Exception e) { throw new IllegalArgumentException(e); } final ListenerInfo listener = new ListenerInfo(listenerClass, factory); deploymentInfo.addListener(listener); deployment.getApplicationListeners().addListener(new ManagedListener(listener, true)); } @Override public T createListener(final Class clazz) throws ServletException { ensureNotProgramaticListener(); if (!ApplicationListeners.isListenerClass(clazz)) { throw UndertowServletMessages.MESSAGES.listenerMustImplementListenerClass(clazz); } try { return deploymentInfo.getClassIntrospecter().createInstanceFactory(clazz).createInstance().getInstance(); } catch (Exception e) { throw UndertowServletMessages.MESSAGES.couldNotInstantiateComponent(clazz.getName(), e); } } @Override public JspConfigDescriptor getJspConfigDescriptor() { return deploymentInfo.getJspConfigDescriptor(); } @Override public ClassLoader getClassLoader() { return deploymentInfo.getClassLoader(); } @Override public void declareRoles(final String... roleNames) { } @Override public String getVirtualServerName() { return deployment.getDeploymentInfo().getHostName(); } /** * Gets the session with the specified ID if it exists * * @param sessionId The session ID * @return The session */ public HttpSessionImpl getSession(final String sessionId) { final SessionManager sessionManager = deployment.getSessionManager(); Session session = sessionManager.getSession(sessionId); if (session != null) { return SecurityActions.forSession(session, this, false); } return null; } public HttpSessionImpl getSession(final ServletContextImpl originalServletContext, final HttpServerExchange exchange, boolean create) { SessionConfig c = originalServletContext.getSessionConfig(); HttpSessionImpl httpSession = exchange.getAttachment(sessionAttachmentKey); if (httpSession != null && httpSession.isInvalid()) { exchange.removeAttachment(sessionAttachmentKey); httpSession = null; } if (httpSession == null) { final SessionManager sessionManager = deployment.getSessionManager(); Session session = sessionManager.getSession(exchange, c); if (session != null) { httpSession = SecurityActions.forSession(session, this, false); exchange.putAttachment(sessionAttachmentKey, httpSession); } else if (create) { String existing = c.findSessionId(exchange); if (originalServletContext != this) { //this is a cross context request //we need to make sure there is a top level session originalServletContext.getSession(originalServletContext, exchange, true); } else if (existing != null) { c.clearSession(exchange, existing); } final Session newSession = sessionManager.createSession(exchange, c); httpSession = SecurityActions.forSession(newSession, this, true); exchange.putAttachment(sessionAttachmentKey, httpSession); } } return httpSession; } /** * Gets the session * * @param create * @return */ public HttpSessionImpl getSession(final HttpServerExchange exchange, boolean create) { return getSession(this, exchange, create); } public void updateSessionAccessTime(final HttpServerExchange exchange) { HttpSessionImpl httpSession = getSession(exchange, false); if (httpSession != null) { Session underlyingSession; if (System.getSecurityManager() == null) { underlyingSession = httpSession.getSession(); } else { underlyingSession = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(httpSession)); } underlyingSession.requestDone(exchange); } } public Deployment getDeployment() { return deployment; } private void ensureNotInitialized() { if (initialized) { throw UndertowServletMessages.MESSAGES.servletContextAlreadyInitialized(); } } private void ensureNotProgramaticListener() { if (ApplicationListeners.listenerState() == PROGRAMATIC_LISTENER) { throw UndertowServletMessages.MESSAGES.cannotCallFromProgramaticListener(); } } boolean isInitialized() { return initialized; } public SessionConfig getSessionConfig() { return sessionConfig; } public void destroy() { attributes.clear(); deploymentInfo = null; } private void readServletAnnotations(ServletInfo servlet) { if (System.getSecurityManager() == null) { new ReadServletAnnotationsTask(servlet, deploymentInfo).run(); } else { AccessController.doPrivileged(new ReadServletAnnotationsTask(servlet, deploymentInfo)); } } public void setDefaultSessionTrackingModes(HashSet sessionTrackingModes) { this.defaultSessionTrackingModes = sessionTrackingModes; this.sessionTrackingModes = sessionTrackingModes; } private static final class ReadServletAnnotationsTask implements PrivilegedAction { private final ServletInfo servletInfo; private final DeploymentInfo deploymentInfo; private ReadServletAnnotationsTask(ServletInfo servletInfo, DeploymentInfo deploymentInfo) { this.servletInfo = servletInfo; this.deploymentInfo = deploymentInfo; } @Override public Void run() { final ServletSecurity security = servletInfo.getServletClass().getAnnotation(ServletSecurity.class); if (security != null) { ServletSecurityInfo servletSecurityInfo = new ServletSecurityInfo() .setEmptyRoleSemantic(security.value().value() == ServletSecurity.EmptyRoleSemantic.DENY ? SecurityInfo.EmptyRoleSemantic.DENY : SecurityInfo.EmptyRoleSemantic.PERMIT) .setTransportGuaranteeType(security.value().transportGuarantee() == ServletSecurity.TransportGuarantee.CONFIDENTIAL ? TransportGuaranteeType.CONFIDENTIAL : TransportGuaranteeType.NONE) .addRolesAllowed(security.value().rolesAllowed()); for (HttpMethodConstraint constraint : security.httpMethodConstraints()) { servletSecurityInfo.addHttpMethodSecurityInfo(new HttpMethodSecurityInfo() .setMethod(constraint.value())) .setEmptyRoleSemantic(constraint.emptyRoleSemantic() == ServletSecurity.EmptyRoleSemantic.DENY ? SecurityInfo.EmptyRoleSemantic.DENY : SecurityInfo.EmptyRoleSemantic.PERMIT) .setTransportGuaranteeType(constraint.transportGuarantee() == ServletSecurity.TransportGuarantee.CONFIDENTIAL ? TransportGuaranteeType.CONFIDENTIAL : TransportGuaranteeType.NONE) .addRolesAllowed(constraint.rolesAllowed()); } servletInfo.setServletSecurityInfo(servletSecurityInfo); } final MultipartConfig multipartConfig = servletInfo.getServletClass().getAnnotation(MultipartConfig.class); if (multipartConfig != null) { servletInfo.setMultipartConfig(new MultipartConfigElement(multipartConfig.location(), multipartConfig.maxFileSize(), multipartConfig.maxRequestSize(), multipartConfig.fileSizeThreshold())); } final RunAs runAs = servletInfo.getServletClass().getAnnotation(RunAs.class); if (runAs != null) { servletInfo.setRunAs(runAs.value()); } final DeclareRoles declareRoles = servletInfo.getServletClass().getAnnotation(DeclareRoles.class); if (declareRoles != null) { deploymentInfo.addSecurityRoles(declareRoles.value()); } return null; } } }