/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.client.transport;

import io.modelcontextprotocol.client.transport.ResponseSubscribers;
import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer;
import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.McpJsonDefaults;
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpTransportException;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

/**
 * {@link McpClientTransport} for the HTTP+SSE transport (protocol version {@code 2024-11-05})
 * built on the JDK {@link HttpClient}.
 *
 * <p>{@link #connect(Function)} opens a long-lived {@code GET} request to the SSE endpoint. The
 * server announces, via an {@code endpoint} event, the URI that outgoing messages must be POSTed
 * to; {@code message} events are deserialized into {@link McpSchema.JSONRPCMessage}s and handed
 * to the connect handler. {@link #sendMessage(McpSchema.JSONRPCMessage)} POSTs serialized
 * messages to the announced endpoint.
 */
public class HttpClientSseClientTransport implements McpClientTransport {
    /** Protocol revision advertised on every request made by this transport. */
    private static final String MCP_PROTOCOL_VERSION = "2024-11-05";
    private static final String MCP_PROTOCOL_VERSION_HEADER_NAME = "MCP-Protocol-Version";
    private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class);
    /** SSE event type whose data is a serialized JSON-RPC message. */
    private static final String MESSAGE_EVENT_TYPE = "message";
    /** SSE event type whose data is the URI that messages must be POSTed to. */
    private static final String ENDPOINT_EVENT_TYPE = "endpoint";
    private static final String DEFAULT_SSE_ENDPOINT = "/sse";
    /**
     * Reactor context key under which callers may supply an {@link McpTransportContext}.
     * NOTE(review): presumably mirrors a constant on McpTransportContext that the decompiler
     * inlined — confirm and reference it directly if so.
     */
    private static final String TRANSPORT_CONTEXT_KEY = "MCP_TRANSPORT_CONTEXT";

    private final URI baseUri;
    private final String sseEndpoint;
    private final HttpClient httpClient;
    /** Template copied for every outgoing request; never mutated by this class. */
    private final HttpRequest.Builder requestBuilder;
    protected McpJsonMapper jsonMapper;
    /**
     * Set by {@link #closeGracefully()}; once true, incoming SSE events and outgoing messages are
     * ignored and SSE stream errors are no longer reported.
     */
    private volatile boolean isClosing = false;
    /** Subscription to the active SSE stream; cleared when that stream terminates. */
    private final AtomicReference<Disposable> sseSubscription = new AtomicReference<>();
    /** Receives the message endpoint URI from the server's {@code endpoint} event. */
    protected final Sinks.One<String> messageEndpointSink = Sinks.one();
    private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer;

    /**
     * Creates a transport. Use {@link #builder(String)} instead of calling this directly.
     *
     * @param httpClient client used for both the SSE stream and message POSTs
     * @param requestBuilder template copied for every request
     * @param baseUri base URI against which the SSE and message endpoints are resolved
     * @param sseEndpoint path of the SSE endpoint, resolved against {@code baseUri}
     * @param jsonMapper mapper used to (de)serialize JSON-RPC messages
     * @param httpRequestCustomizer hook applied to every request before it is sent
     * @throws IllegalArgumentException (via {@code Assert}) if any argument is null or empty
     */
    HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) {
        Assert.notNull(jsonMapper, "jsonMapper must not be null");
        Assert.hasText(baseUri, "baseUri must not be empty");
        Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
        Assert.notNull(httpClient, "httpClient must not be null");
        Assert.notNull(requestBuilder, "requestBuilder must not be null");
        Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null");
        this.baseUri = URI.create(baseUri);
        this.sseEndpoint = sseEndpoint;
        this.jsonMapper = jsonMapper;
        this.httpClient = httpClient;
        this.requestBuilder = requestBuilder;
        this.httpRequestCustomizer = httpRequestCustomizer;
    }

    /** Returns the single protocol version this transport speaks. */
    @Override
    public List<String> protocolVersions() {
        return List.of(MCP_PROTOCOL_VERSION);
    }

    /**
     * Starts building a transport for the given server base URI.
     *
     * @param baseUri base URI of the MCP server; must not be empty
     * @return a new builder
     */
    public static Builder builder(String baseUri) {
        return new Builder().baseUri(baseUri);
    }

    /**
     * Opens the SSE stream and routes incoming messages to {@code handler}.
     *
     * <p>The returned {@code Mono} completes on the first successfully handled SSE event (normally
     * the server's {@code endpoint} event), or errors if the stream fails before that. Unrecognized
     * event types are logged and skipped; they do not terminate the stream.
     *
     * @param handler invoked with each incoming JSON-RPC message
     * @return a {@code Mono} signalling that the connection is established
     */
    @Override
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        URI uri = Utils.resolveUri(this.baseUri, this.sseEndpoint);
        return Mono.deferContextual(ctx -> {
            HttpRequest.Builder builder = this.requestBuilder.copy()
                .uri(uri)
                .header("Accept", "text/event-stream")
                .header("Cache-Control", "no-cache")
                .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION)
                .GET();
            McpTransportContext transportContext = ctx.getOrDefault(TRANSPORT_CONTEXT_KEY, McpTransportContext.EMPTY);
            return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext));
        }).flatMap(customizedBuilder -> Mono.<Void>create(sink -> {
            Disposable connection = Flux.<ResponseSubscribers.ResponseEvent>create(sseSink -> this.httpClient
                    .sendAsync(customizedBuilder.build(),
                            responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))
                    .exceptionallyCompose(e -> {
                        // Surface connection-level failures through the SSE flux.
                        sseSink.error(e);
                        return CompletableFuture.failedFuture(e);
                    }))
                .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent)
                .flatMap(responseEvent -> {
                    if (this.isClosing) {
                        return Flux.<McpSchema.JSONRPCMessage>empty();
                    }
                    int statusCode = responseEvent.responseInfo().statusCode();
                    if (!isSuccessful(statusCode)) {
                        return Flux.<McpSchema.JSONRPCMessage>error(
                                new RuntimeException("Failed to send message: " + responseEvent));
                    }
                    String eventType = responseEvent.sseEvent().event();
                    if (ENDPOINT_EVENT_TYPE.equals(eventType)) {
                        String messageEndpointUri = responseEvent.sseEvent().data();
                        if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) {
                            sink.success();
                            return Flux.<McpSchema.JSONRPCMessage>empty();
                        }
                        // e.g. a second endpoint event: report the real cause on both channels.
                        RuntimeException endpointError = new RuntimeException("Failed to handle SSE endpoint event");
                        sink.error(endpointError);
                        return Flux.<McpSchema.JSONRPCMessage>error(endpointError);
                    }
                    if (MESSAGE_EVENT_TYPE.equals(eventType)) {
                        try {
                            McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(
                                    this.jsonMapper, responseEvent.sseEvent().data());
                            sink.success();
                            return Flux.just(message);
                        }
                        catch (IOException e) {
                            McpTransportException parseError = new McpTransportException("Error processing SSE event", e);
                            sink.error(parseError);
                            return Flux.<McpSchema.JSONRPCMessage>error(parseError);
                        }
                    }
                    // Unknown event types are skipped rather than tearing down the stream.
                    logger.debug("Received unrecognized SSE event type: {}", responseEvent.sseEvent());
                    sink.success();
                    return Flux.<McpSchema.JSONRPCMessage>empty();
                })
                .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage)))
                .onErrorComplete(t -> {
                    // Errors during an intentional shutdown are expected and not reported.
                    if (!this.isClosing) {
                        logger.warn("SSE stream observed an error", t);
                        sink.error(t);
                    }
                    return true;
                })
                .doFinally(signal -> {
                    Disposable ref = this.sseSubscription.getAndSet(null);
                    if (ref != null && !ref.isDisposed()) {
                        ref.dispose();
                    }
                })
                .contextWrite(sink.contextView())
                .subscribe();
            this.sseSubscription.set(connection);
        }));
    }

    /**
     * POSTs {@code message} to the endpoint announced by the server.
     *
     * <p>Waits until the endpoint URI is known. Once the transport is closing, the message is
     * silently dropped. Any 2xx response is treated as success.
     *
     * @param message the JSON-RPC message to send
     * @return a {@code Mono} that completes when the server accepted the message, or errors with a
     *     {@link McpTransportException} (serialization) or {@link RuntimeException} (non-2xx status)
     */
    @Override
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
        return this.messageEndpointSink.asMono().flatMap(messageEndpointUri -> {
            if (this.isClosing) {
                return Mono.<HttpResponse<String>>empty();
            }
            return this.serializeMessage(message)
                .flatMap(body -> this.sendHttpPost(messageEndpointUri, body))
                .<HttpResponse<String>>handle((response, sink) -> {
                    if (isSuccessful(response.statusCode())) {
                        sink.next(response);
                        sink.complete();
                    }
                    else {
                        sink.error(new RuntimeException("Sending message failed with a non-OK HTTP code: "
                                + response.statusCode() + " - " + response.body()));
                    }
                })
                .doOnError(error -> {
                    if (!this.isClosing) {
                        logger.error("Error sending message: {}", error.getMessage());
                    }
                });
        }).then();
    }

    /** Serializes {@code message} lazily, mapping I/O failures to {@link McpTransportException}. */
    private Mono<String> serializeMessage(McpSchema.JSONRPCMessage message) {
        return Mono.defer(() -> {
            try {
                return Mono.just(this.jsonMapper.writeValueAsString(message));
            }
            catch (IOException e) {
                return Mono.error(new McpTransportException("Failed to serialize message", e));
            }
        });
    }

    /**
     * Sends {@code body} as a JSON POST to {@code endpoint} (resolved against the base URI) after
     * running the request customizer.
     *
     * <p>NOTE(review): {@code endpoint} is supplied by the server and is not checked to share the
     * base URI's origin, so headers from the request template may be sent to another host.
     */
    private Mono<HttpResponse<String>> sendHttpPost(String endpoint, String body) {
        URI requestUri = Utils.resolveUri(this.baseUri, endpoint);
        return Mono.deferContextual(ctx -> {
            HttpRequest.Builder builder = this.requestBuilder.copy()
                .uri(requestUri)
                .header("Content-Type", "application/json")
                .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION)
                .POST(HttpRequest.BodyPublishers.ofString(body));
            McpTransportContext transportContext = ctx.getOrDefault(TRANSPORT_CONTEXT_KEY, McpTransportContext.EMPTY);
            return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body, transportContext));
        }).flatMap(customizedBuilder -> Mono.fromFuture(
                this.httpClient.sendAsync(customizedBuilder.build(), HttpResponse.BodyHandlers.ofString())));
    }

    /** Returns {@code true} for any HTTP 2xx status code. */
    private static boolean isSuccessful(int statusCode) {
        return statusCode >= 200 && statusCode < 300;
    }

    /**
     * Marks the transport as closing and disposes the active SSE subscription, if any.
     *
     * @return a {@code Mono} that performs the shutdown on subscription
     */
    @Override
    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(() -> {
            this.isClosing = true;
            Disposable subscription = this.sseSubscription.get();
            if (subscription != null && !subscription.isDisposed()) {
                subscription.dispose();
            }
        });
    }

    /** Converts {@code data} to {@code T} using the configured JSON mapper. */
    @Override
    public <T> T unmarshalFrom(Object data, TypeRef<T> typeRef) {
        return this.jsonMapper.convertValue(data, typeRef);
    }

    /**
     * Builder for {@link HttpClientSseClientTransport}; obtain one via
     * {@link HttpClientSseClientTransport#builder(String)}.
     */
    public static class Builder {
        private String baseUri;
        private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
        private HttpClient.Builder clientBuilder = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1);
        private McpJsonMapper jsonMapper;
        private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder();
        private McpAsyncHttpClientRequestCustomizer httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.NOOP;
        private Duration connectTimeout = Duration.ofSeconds(10L);

        Builder() {
        }

        /**
         * Creates a builder for the given base URI.
         *
         * @param baseUri base URI of the MCP server; must not be empty
         * @deprecated use {@link HttpClientSseClientTransport#builder(String)} instead
         */
        @Deprecated(forRemoval=true)
        public Builder(String baseUri) {
            Assert.hasText(baseUri, "baseUri must not be empty");
            this.baseUri = baseUri;
        }

        /** Sets the server base URI; must not be empty. */
        Builder baseUri(String baseUri) {
            Assert.hasText(baseUri, "baseUri must not be empty");
            this.baseUri = baseUri;
            return this;
        }

        /** Sets the SSE endpoint path, resolved against the base URI (default {@code /sse}). */
        public Builder sseEndpoint(String sseEndpoint) {
            Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
            this.sseEndpoint = sseEndpoint;
            return this;
        }

        /** Replaces the {@link HttpClient.Builder} used to create the client. */
        public Builder clientBuilder(HttpClient.Builder clientBuilder) {
            Assert.notNull(clientBuilder, "clientBuilder must not be null");
            this.clientBuilder = clientBuilder;
            return this;
        }

        /** Applies {@code clientCustomizer} to the current {@link HttpClient.Builder}. */
        public Builder customizeClient(Consumer<HttpClient.Builder> clientCustomizer) {
            Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
            clientCustomizer.accept(this.clientBuilder);
            return this;
        }

        /** Replaces the request template copied for every outgoing request. */
        public Builder requestBuilder(HttpRequest.Builder requestBuilder) {
            Assert.notNull(requestBuilder, "requestBuilder must not be null");
            this.requestBuilder = requestBuilder;
            return this;
        }

        /** Applies {@code requestCustomizer} to the current request template. */
        public Builder customizeRequest(Consumer<HttpRequest.Builder> requestCustomizer) {
            Assert.notNull(requestCustomizer, "requestCustomizer must not be null");
            requestCustomizer.accept(this.requestBuilder);
            return this;
        }

        /** Sets the JSON mapper; defaults to {@code McpJsonDefaults.getMapper()} when unset. */
        public Builder jsonMapper(McpJsonMapper jsonMapper) {
            Assert.notNull(jsonMapper, "jsonMapper must not be null");
            this.jsonMapper = jsonMapper;
            return this;
        }

        /** Sets a synchronous per-request customizer (adapted to the async form). */
        public Builder httpRequestCustomizer(McpSyncHttpClientRequestCustomizer syncHttpRequestCustomizer) {
            // Validate eagerly: the adapter would otherwise defer the failure to request time.
            Assert.notNull(syncHttpRequestCustomizer, "syncHttpRequestCustomizer must not be null");
            this.httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.fromSync(syncHttpRequestCustomizer);
            return this;
        }

        /** Sets an asynchronous per-request customizer. */
        public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer asyncHttpRequestCustomizer) {
            Assert.notNull(asyncHttpRequestCustomizer, "asyncHttpRequestCustomizer must not be null");
            this.httpRequestCustomizer = asyncHttpRequestCustomizer;
            return this;
        }

        /** Sets the HTTP connect timeout (default 10 seconds). */
        public Builder connectTimeout(Duration connectTimeout) {
            Assert.notNull(connectTimeout, "connectTimeout must not be null");
            this.connectTimeout = connectTimeout;
            return this;
        }

        /**
         * Builds the transport. The configured connect timeout is applied to the client builder,
         * overriding any timeout set on it directly.
         */
        public HttpClientSseClientTransport build() {
            HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build();
            McpJsonMapper mapper = this.jsonMapper == null ? McpJsonDefaults.getMapper() : this.jsonMapper;
            return new HttpClientSseClientTransport(httpClient, this.requestBuilder, this.baseUri, this.sseEndpoint, mapper, this.httpRequestCustomizer);
        }
    }
}

