/*
 * Decompiled with CFR 0.152.
 */
package com.linecorp.armeria.internal.server.websocket;

import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.ResponseHeadersBuilder;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.stream.ClosedStreamException;
import com.linecorp.armeria.common.stream.StreamMessage;
import com.linecorp.armeria.common.util.TimeoutMode;
import com.linecorp.armeria.common.websocket.WebSocket;
import com.linecorp.armeria.common.websocket.WebSocketFrame;
import com.linecorp.armeria.internal.common.websocket.WebSocketFrameEncoder;
import com.linecorp.armeria.internal.common.websocket.WebSocketUtil;
import com.linecorp.armeria.internal.common.websocket.WebSocketWrapper;
import com.linecorp.armeria.internal.server.websocket.WebSocketServiceFrameDecoder;
import com.linecorp.armeria.internal.shaded.guava.base.Ascii;
import com.linecorp.armeria.internal.shaded.guava.base.Splitter;
import com.linecorp.armeria.internal.shaded.guava.net.HostAndPort;
import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.ServiceConfig;
import com.linecorp.armeria.server.ServiceOptions;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.websocket.WebSocketProtocolHandler;
import com.linecorp.armeria.server.websocket.WebSocketService;
import com.linecorp.armeria.server.websocket.WebSocketServiceHandler;
import com.linecorp.armeria.server.websocket.WebSocketUpgradeResult;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.util.AttributeKey;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link WebSocketService} implementation that also serves as its own
 * {@link WebSocketProtocolHandler}.
 *
 * <p>Upgrade handling:
 * <ul>
 *   <li>{@code GET} over explicit HTTP/1 requires the HTTP/1 upgrade headers and a non-empty
 *       {@code Sec-WebSocket-Key}.</li>
 *   <li>{@code CONNECT} over explicit HTTP/2 requires the HTTP/2 upgrade headers
 *       ({@code :protocol = websocket}).</li>
 *   <li>Both paths then check the {@code Origin} header and require
 *       {@code Sec-WebSocket-Version: 13}.</li>
 * </ul>
 * Requests with the wrong method, session protocol or missing upgrade headers are handed to the
 * optional fallback service. Origin, version and key failures are always rejected directly.
 */
public final class DefaultWebSocketService
        implements WebSocketProtocolHandler, WebSocketService {

    private static final Logger logger = LoggerFactory.getLogger(DefaultWebSocketService.class);

    // Per-request decoder created in decode() and looked up again in encode().
    private static final AttributeKey<WebSocketServiceFrameDecoder> DECODER =
            AttributeKey.valueOf(DefaultWebSocketService.class, "DECODER");

    private static final String SUB_PROTOCOL_WILDCARD = "*";

    // Advertises the only supported version back to the client when its version is rejected.
    private static final ResponseHeaders UNSUPPORTED_WEB_SOCKET_VERSION =
            ResponseHeaders.builder(HttpStatus.BAD_REQUEST)
                           .add(HttpHeaderNames.SEC_WEBSOCKET_VERSION,
                                WebSocketVersion.V13.toHttpHeaderValue())
                           .contentType(MediaType.PLAIN_TEXT_UTF_8)
                           .build();

    private static final Splitter commaSplitter = Splitter.on(',').trimResults().omitEmptyStrings();

    // Outbound frames are encoded with WebSocketFrameEncoder.of(false).
    private static final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(false);

    private final WebSocketServiceHandler handler;
    @Nullable
    private final HttpService fallbackService;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;
    private final Set<String> subprotocols;
    private final boolean allowAnyOrigin;
    // When null (and allowAnyOrigin is false), only same-origin requests are accepted.
    @Nullable
    private final Predicate<? super String> originPredicate;
    private final boolean aggregateContinuation;
    private final ServiceOptions serviceOptions;

    public DefaultWebSocketService(WebSocketServiceHandler handler, @Nullable HttpService fallbackService,
                                   int maxFramePayloadLength, boolean allowMaskMismatch,
                                   Set<String> subprotocols, boolean allowAnyOrigin,
                                   @Nullable Predicate<? super String> originPredicate,
                                   boolean aggregateContinuation, ServiceOptions serviceOptions) {
        this.handler = handler;
        this.fallbackService = fallbackService;
        this.maxFramePayloadLength = maxFramePayloadLength;
        this.allowMaskMismatch = allowMaskMismatch;
        this.subprotocols = subprotocols;
        this.allowAnyOrigin = allowAnyOrigin;
        this.originPredicate = originPredicate;
        this.aggregateContinuation = aggregateContinuation;
        this.serviceOptions = serviceOptions;
    }

    /**
     * Delegates the inbound WebSocket to the user-supplied {@link WebSocketServiceHandler}.
     */
    @Override
    public WebSocket serve(ServiceRequestContext ctx, WebSocket in) throws Exception {
        return handler.handle(ctx, in);
    }

    /**
     * Propagates the {@code serviceAdded} notification to the fallback service, if any.
     */
    @Override
    public void serviceAdded(ServiceConfig cfg) throws Exception {
        if (fallbackService != null) {
            fallbackService.serviceAdded(cfg);
        }
    }

    /**
     * Validates a WebSocket upgrade request, dispatching on the request method.
     * {@code GET} is handled as an HTTP/1 upgrade and {@code CONNECT} as an HTTP/2 upgrade.
     * Any other method fails with {@code 405 Method Not Allowed} or goes to the fallback service.
     */
    @Override
    public WebSocketUpgradeResult upgrade(ServiceRequestContext ctx, HttpRequest req) throws Exception {
        final HttpMethod method = ctx.method();
        switch (method) {
            case GET:
                return upgradeHttp1(ctx, req);
            case CONNECT:
                return upgradeHttp2(ctx, req);
            default:
                return WebSocketUpgradeResult.ofFailure(
                        failOrFallback(ctx, req, () -> HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED)));
        }
    }

    private WebSocketUpgradeResult upgradeHttp1(ServiceRequestContext ctx, HttpRequest req) throws Exception {
        if (!ctx.sessionProtocol().isExplicitHttp1()) {
            // A GET over a non-HTTP/1 session cannot be an HTTP/1 upgrade.
            return WebSocketUpgradeResult.ofFailure(
                    failOrFallback(ctx, req, () -> HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED)));
        }
        final RequestHeaders headers = req.headers();
        if (!WebSocketUtil.isHttp1WebSocketUpgradeRequest(headers)) {
            return WebSocketUpgradeResult.ofFailure(failOrFallback(
                    ctx, req,
                    () -> HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
                                          "The upgrade header must contain:\n  Upgrade: websocket\n  Connection: Upgrade")));
        }
        HttpResponse invalidResponse = checkOrigin(ctx, headers);
        if (invalidResponse != null) {
            return WebSocketUpgradeResult.ofFailure(invalidResponse);
        }
        invalidResponse = checkVersion(headers);
        if (invalidResponse != null) {
            return WebSocketUpgradeResult.ofFailure(invalidResponse);
        }
        // encode() derives Sec-WebSocket-Accept from this key, so it must be present.
        final String webSocketKey = headers.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, "");
        if (webSocketKey.isEmpty()) {
            return WebSocketUpgradeResult.ofFailure(
                    HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
                                    "missing Sec-WebSocket-Key header"));
        }
        return WebSocketUpgradeResult.ofSuccess();
    }

    /**
     * Serves {@code req} with the fallback service if one is configured, otherwise returns the
     * response produced by {@code invalidResponse}.
     *
     * <p>Before delegating, the fallback service's request timeout, max request length and request
     * auto-abort delay are applied to {@code ctx}. For each option with a negative value, the
     * virtual host's value is used instead.
     */
    private HttpResponse failOrFallback(ServiceRequestContext ctx, HttpRequest req,
                                        Supplier<HttpResponse> invalidResponse) throws Exception {
        if (fallbackService == null) {
            return invalidResponse.get();
        }
        final ServiceOptions options = fallbackService.options();

        long requestTimeoutMillis = options.requestTimeoutMillis();
        if (requestTimeoutMillis < 0L) {
            requestTimeoutMillis = ctx.config().virtualHost().requestTimeoutMillis();
        }
        ctx.setRequestTimeoutMillis(TimeoutMode.SET_FROM_START, requestTimeoutMillis);

        long maxRequestLength = options.maxRequestLength();
        if (maxRequestLength < 0L) {
            maxRequestLength = ctx.config().virtualHost().maxRequestLength();
        }
        ctx.setMaxRequestLength(maxRequestLength);

        long requestAutoAbortDelayMillis = options.requestAutoAbortDelayMillis();
        if (requestAutoAbortDelayMillis < 0L) {
            requestAutoAbortDelayMillis = ctx.config().virtualHost().requestAutoAbortDelayMillis();
        }
        ctx.setRequestAutoAbortDelayMillis(requestAutoAbortDelayMillis);

        return fallbackService.serve(ctx, req);
    }

    /**
     * Adds {@code Sec-WebSocket-Protocol} to the response with the first acceptable entry of the
     * client's comma-separated {@code Sec-WebSocket-Protocol} list. Nothing is added if no entry
     * is acceptable.
     */
    private void maybeAddSubprotocol(RequestHeaders headers, ResponseHeadersBuilder responseHeadersBuilder) {
        final String requested = headers.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "");
        if (requested.isEmpty()) {
            return;
        }
        // NOTE(review): the wildcard is compared against the client's *requested* values, so a client
        // sending "*" gets "*" echoed back, while a configured "*" does not accept arbitrary values.
        // Confirm this is intended; behavior is deliberately left unchanged here.
        commaSplitter.splitToStream(requested)
                     .filter(sub -> SUB_PROTOCOL_WILDCARD.equals(sub) || subprotocols.contains(sub))
                     .findFirst()
                     .ifPresent(selected -> responseHeadersBuilder.add(
                             HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selected));
    }

    private WebSocketUpgradeResult upgradeHttp2(ServiceRequestContext ctx, HttpRequest req) throws Exception {
        if (!ctx.sessionProtocol().isExplicitHttp2()) {
            // A CONNECT over a non-HTTP/2 session cannot be an HTTP/2 WebSocket upgrade.
            return WebSocketUpgradeResult.ofFailure(
                    failOrFallback(ctx, req, () -> HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED)));
        }
        final RequestHeaders headers = req.headers();
        if (!WebSocketUtil.isHttp2WebSocketUpgradeRequest(headers)) {
            logger.trace("RequestHeaders does not contain headers for WebSocket upgrade. headers: {}", headers);
            return WebSocketUpgradeResult.ofFailure(failOrFallback(
                    ctx, req,
                    () -> HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
                                          "The upgrade header must contain:\n  :protocol = websocket")));
        }
        HttpResponse invalidResponse = checkOrigin(ctx, headers);
        if (invalidResponse != null) {
            return WebSocketUpgradeResult.ofFailure(invalidResponse);
        }
        invalidResponse = checkVersion(headers);
        if (invalidResponse != null) {
            return WebSocketUpgradeResult.ofFailure(invalidResponse);
        }
        return WebSocketUpgradeResult.ofSuccess();
    }

    /**
     * Validates the {@code Origin} header.
     *
     * @return {@code null} if the origin is acceptable, otherwise a {@code 403 Forbidden} response
     */
    @Nullable
    private HttpResponse checkOrigin(ServiceRequestContext ctx, RequestHeaders headers) {
        if (allowAnyOrigin) {
            return null;
        }
        final String origin = headers.get(HttpHeaderNames.ORIGIN, "");
        if (origin.isEmpty()) {
            return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, "missing the origin header");
        }
        final String lowerCaseOrigin = Ascii.toLowerCase(origin);
        // Without a user-supplied predicate, only same-origin requests are allowed.
        final boolean allowed = originPredicate != null ? originPredicate.test(lowerCaseOrigin)
                                                        : isSameOrigin(ctx, headers, lowerCaseOrigin);
        if (!allowed) {
            return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8,
                                   "not allowed origin: " + lowerCaseOrigin);
        }
        return null;
    }

    /**
     * Returns whether {@code origin} (already lower-cased) names the same scheme family, host and
     * port as the request's authority.
     *
     * <p>Ports default to the scheme's default port. Host names are compared case-insensitively.
     * Any origin or authority that cannot be parsed is treated as not same-origin, so the caller
     * answers {@code 403} instead of failing with an exception.
     */
    private static boolean isSameOrigin(ServiceRequestContext ctx, RequestHeaders headers, String origin) {
        final int schemeDelimiter = origin.indexOf("://");
        if (schemeDelimiter < 0) {
            // No "<scheme>://" prefix, e.g. the opaque origin "null".
            return false;
        }
        final SessionProtocol originSessionProtocol =
                SessionProtocol.find(origin.substring(0, schemeDelimiter));
        if (originSessionProtocol == null) {
            return false;
        }
        final SessionProtocol sessionProtocol = ctx.sessionProtocol();
        final boolean sameSecurity = (sessionProtocol.isHttp() && originSessionProtocol.isHttp()) ||
                                     (sessionProtocol.isHttps() && originSessionProtocol.isHttps());
        if (!sameSecurity) {
            return false;
        }
        final String authority = headers.authority();
        if (authority == null) {
            // Previously only guarded by an assert, which would NPE below when assertions are off.
            return false;
        }
        final HostAndPort authorityHostAndPort;
        final HostAndPort originHostAndPort;
        try {
            authorityHostAndPort = HostAndPort.fromString(authority);
            originHostAndPort = HostAndPort.fromString(origin.substring(schemeDelimiter + 3));
        } catch (IllegalArgumentException ignored) {
            // The Origin header is client-controlled; an unparsable host/port (e.g. a non-numeric
            // port) must be rejected as a foreign origin, not escape as an unexpected exception.
            return false;
        }
        final int authorityPort = authorityHostAndPort.getPortOrDefault(sessionProtocol.defaultPort());
        final int originPort = originHostAndPort.getPortOrDefault(originSessionProtocol.defaultPort());
        // The origin is lower-cased by the caller but the authority is not, and host names are
        // case-insensitive, so compare them ignoring ASCII case.
        return authorityPort == originPort &&
               Ascii.equalsIgnoreCase(authorityHostAndPort.getHost(), originHostAndPort.getHost());
    }

    /**
     * Checks that {@code Sec-WebSocket-Version} is {@code 13}, ignoring case.
     *
     * @return {@code null} if supported, otherwise a {@code 400 Bad Request} advertising version 13
     */
    @Nullable
    private static HttpResponse checkVersion(RequestHeaders headers) {
        final String version = headers.get(HttpHeaderNames.SEC_WEBSOCKET_VERSION);
        if (WebSocketVersion.V13.toHttpHeaderValue().equalsIgnoreCase(version)) {
            return null;
        }
        return HttpResponse.of(UNSUPPORTED_WEB_SOCKET_VERSION, HttpData.ofUtf8("Only 13 version is supported."));
    }

    /**
     * Wraps the request body in a frame decoder.
     *
     * <p>The decoder is stored in {@code ctx} so that {@link #encode} can attach the outbound
     * WebSocket to it.
     */
    @Override
    public WebSocket decode(ServiceRequestContext ctx, HttpRequest req) {
        final WebSocketServiceFrameDecoder decoder =
                new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch,
                                                 aggregateContinuation);
        ctx.setAttr(DECODER, decoder);
        return new WebSocketWrapper(req.decode(decoder, ctx.alloc()));
    }

    /**
     * Builds the upgrade response whose body is the encoded outbound frames of {@code out}.
     *
     * <p>Over explicit HTTP/1 the response is {@code 101 Switching Protocols} with
     * {@code Upgrade}, {@code Connection} and {@code Sec-WebSocket-Accept}. Otherwise it is a
     * plain {@code 200 OK}.
     *
     * <p>A {@link ClosedStreamException} in the outbound stream aborts the stream. Any other
     * failure is recorded as the response cause and replaced by a close frame built from it.
     */
    @Override
    public HttpResponse encode(ServiceRequestContext ctx, WebSocket out) {
        final RequestHeaders requestHeaders = ctx.request().headers();
        final ResponseHeadersBuilder responseHeadersBuilder;
        if (ctx.sessionProtocol().isExplicitHttp1()) {
            final String webSocketKey = requestHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, "");
            final String accept = WebSocketUtil.generateSecWebSocketAccept(webSocketKey);
            responseHeadersBuilder =
                    ResponseHeaders.builder(HttpStatus.SWITCHING_PROTOCOLS)
                                   .add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString())
                                   .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString())
                                   .add(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept);
        } else {
            responseHeadersBuilder = ResponseHeaders.builder(HttpStatus.OK);
        }
        maybeAddSubprotocol(requestHeaders, responseHeadersBuilder);

        // decode() runs before encode() for the same request and stores the decoder.
        final WebSocketServiceFrameDecoder decoder = ctx.attr(DECODER);
        assert decoder != null;
        decoder.setOutboundWebSocket(out);

        final StreamMessage<HttpData> data = out.recoverAndResume(cause -> {
            if (cause instanceof ClosedStreamException) {
                return StreamMessage.aborted(cause);
            }
            ctx.logBuilder().responseCause(cause);
            return StreamMessage.of(WebSocketUtil.newCloseWebSocketFrame(cause));
        }).map(frame -> HttpData.wrap(encoder.encode(ctx, frame)));
        return HttpResponse.of(responseHeadersBuilder.build(), data);
    }

    @Override
    public WebSocketProtocolHandler protocolHandler() {
        return this;
    }

    @Override
    public ServiceOptions options() {
        return serviceOptions;
    }
}

