/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.helper;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.memorycontainer.MemoryStorageConfig;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.transport.client.Client;

/**
 * Generates dense or sparse embeddings for memory text by running an ML prediction task against the
 * embedding model configured in a {@link MemoryStorageConfig}.
 */
public class MemoryEmbeddingHelper {
    @Generated
    private static final Logger log = LogManager.getLogger(MemoryEmbeddingHelper.class);
    private final Client client;
    private final MLModelManager mlModelManager;

    @Inject
    public MemoryEmbeddingHelper(Client client, MLModelManager mlModelManager) {
        this.client = client;
        this.mlModelManager = mlModelManager;
    }

    /**
     * Generates embeddings for a batch of texts.
     *
     * @param texts texts to embed; a {@code null} or empty list completes immediately with an empty list
     * @param storageConfig supplies the embedding model id and type
     * @param listener receives one entry per model output, or a failure if the model configuration is
     *     missing, the model is not usable, or prediction/extraction fails
     */
    public void generateEmbeddingsForMultipleTexts(List<String> texts, MemoryStorageConfig storageConfig, ActionListener<List<Object>> listener) {
        if (texts == null || texts.isEmpty()) {
            listener.onResponse(new ArrayList<>());
            return;
        }
        generateEmbeddingsInternal(texts, storageConfig, listener);
    }

    /**
     * Validates the configured embedding model, runs a prediction over {@code texts}, and converts the
     * resulting tensors into embeddings. All outcomes are reported through {@code listener}.
     */
    private void generateEmbeddingsInternal(List<String> texts, MemoryStorageConfig storageConfig, ActionListener<List<Object>> listener) {
        // Report missing configuration through the listener rather than throwing an NPE on the caller's thread.
        if (storageConfig == null || storageConfig.getEmbeddingModelId() == null || storageConfig.getEmbeddingModelType() == null) {
            listener.onFailure(new IllegalArgumentException("Embedding model configuration is missing"));
            return;
        }
        String embeddingModelId = storageConfig.getEmbeddingModelId();
        FunctionName embeddingModelType = storageConfig.getEmbeddingModelType();
        validateEmbeddingModelState(embeddingModelId, embeddingModelType, ActionListener.wrap(isValid -> {
            MLInput mlInput = MLInput.builder()
                .algorithm(embeddingModelType)
                .inputDataset(TextDocsInputDataSet.builder().docs(texts).build())
                .build();
            MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
                .modelId(embeddingModelId)
                .mlInput(mlInput)
                .build();
            client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, ActionListener.wrap(response -> {
                MLOutput mlOutput = response.getOutput();
                if (!(mlOutput instanceof ModelTensorOutput)) {
                    // mlOutput may be null; describe it without calling getClass() on null.
                    String outputType = mlOutput == null ? "null" : mlOutput.getClass().getName();
                    log.error("Unexpected ML output type: {}", outputType);
                    listener.onFailure(new IllegalStateException("Unexpected ML output type: " + outputType));
                    return;
                }
                List<Object> embeddings;
                try {
                    embeddings = extractEmbeddings((ModelTensorOutput) mlOutput, embeddingModelType);
                } catch (Exception e) {
                    log.error("Failed to extract embeddings from ML output", e);
                    listener.onFailure(new IllegalStateException("Failed to extract embeddings from ML output", e));
                    return;
                }
                // Kept outside the try so exceptions thrown by the caller's listener are not
                // misreported as extraction failures.
                listener.onResponse(embeddings);
            }, e -> {
                log.error("Failed to generate embeddings", e);
                listener.onFailure(e);
            }));
        }, listener::onFailure));
    }

    /**
     * Converts each {@link ModelTensors} entry of a prediction output into one embedding.
     *
     * <p>TEXT_EMBEDDING models yield a {@code float[]}, and SPARSE_ENCODING models yield a map. Any other
     * model type, or a tensor set with no usable data, contributes a {@code null} entry. The result has
     * one entry per ModelTensors, presumably one per input text (TODO confirm).
     */
    private List<Object> extractEmbeddings(ModelTensorOutput tensorOutput, FunctionName embeddingModelType) {
        List<Object> embeddings = new ArrayList<>();
        if (tensorOutput.getMlModelOutputs() == null) {
            return embeddings;
        }
        for (ModelTensors modelTensors : tensorOutput.getMlModelOutputs()) {
            Object embedding = null;
            if (embeddingModelType == FunctionName.TEXT_EMBEDDING) {
                embedding = extractDenseEmbeddingFromModelTensors(modelTensors);
            } else if (embeddingModelType == FunctionName.SPARSE_ENCODING) {
                embedding = extractSparseEmbeddingFromModelTensors(modelTensors);
            }
            embeddings.add(embedding);
        }
        return embeddings;
    }

    /**
     * Generates a single embedding, best-effort.
     *
     * <p>The listener completes with {@code null} in three cases: semantic storage is disabled (or the
     * config is {@code null}), the model id/type is missing, or generation fails for any reason.
     * Failures are logged rather than propagated.
     */
    public void generateEmbedding(String text, MemoryStorageConfig storageConfig, ActionListener<Object> listener) {
        if (storageConfig == null || !storageConfig.isSemanticStorageEnabled()) {
            listener.onResponse(null);
            return;
        }
        if (storageConfig.getEmbeddingModelId() == null || storageConfig.getEmbeddingModelType() == null) {
            log.error("Embedding model configuration is missing");
            listener.onResponse(null);
            return;
        }
        generateEmbeddingsInternal(Arrays.asList(text), storageConfig, ActionListener.wrap(
            embeddings -> listener.onResponse(embeddings == null || embeddings.isEmpty() ? null : embeddings.get(0)),
            e -> {
                // The failure may come from validation, prediction, or extraction.
                log.error("Failed to generate embedding", e);
                listener.onResponse(null);
            }
        ));
    }

    /**
     * Checks that the embedding model can serve predictions.
     *
     * <p>A {@code REMOTE} model type is accepted without a lookup. Otherwise the model is fetched with
     * the thread context stashed; the context is restored before the inner listener runs. The model is
     * accepted when its algorithm is REMOTE or its state is DEPLOYED / PARTIALLY_DEPLOYED.
     *
     * @param listener receives {@code true} when usable; otherwise an {@link IllegalStateException}
     */
    public void validateEmbeddingModelState(String modelId, FunctionName modelType, ActionListener<Boolean> listener) {
        if (modelType == FunctionName.REMOTE) {
            listener.onResponse(true);
            return;
        }
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLModel> modelListener = ActionListener.runBefore(ActionListener.wrap(model -> {
                MLModelState modelState = model.getModelState();
                boolean usable = model.getAlgorithm() == FunctionName.REMOTE
                    || modelState == MLModelState.DEPLOYED
                    || modelState == MLModelState.PARTIALLY_DEPLOYED;
                if (usable) {
                    listener.onResponse(true);
                } else {
                    listener.onFailure(new IllegalStateException("Embedding model must be in DEPLOYED state, current state: " + modelState));
                }
            }, e -> {
                log.error("Failed to get embedding model: {}", modelId, e);
                listener.onFailure(new IllegalStateException("Failed to validate embedding model state", e));
            }), context::restore);
            mlModelManager.getModel(modelId, modelListener);
        }
    }

    /**
     * Returns the first tensor named {@code "sentence_embedding"} with non-null data, converted to a
     * {@code float[]}. Returns {@code null} if no such tensor exists.
     */
    private Object extractDenseEmbeddingFromModelTensors(ModelTensors modelTensors) {
        if (modelTensors.getMlModelTensors() == null) {
            return null;
        }
        for (ModelTensor tensor : modelTensors.getMlModelTensors()) {
            if (!"sentence_embedding".equals(tensor.getName()) || tensor.getData() == null) {
                continue;
            }
            Number[] data = tensor.getData();
            float[] floatData = new float[data.length];
            for (int i = 0; i < data.length; i++) {
                floatData[i] = data[i].floatValue();
            }
            return floatData;
        }
        return null;
    }

    /**
     * Returns the sparse embedding from the first tensor that has a non-null data map.
     *
     * <p>If that map holds a {@code "response"} list whose first element is a map, that element is
     * returned. Otherwise the data map itself is returned. Returns {@code null} if no tensor has a
     * data map.
     */
    private Object extractSparseEmbeddingFromModelTensors(ModelTensors modelTensors) {
        if (modelTensors.getMlModelTensors() == null) {
            return null;
        }
        for (ModelTensor tensor : modelTensors.getMlModelTensors()) {
            Map<?, ?> dataMap = tensor.getDataAsMap();
            if (dataMap == null) {
                continue;
            }
            Object response = dataMap.get("response");
            if (response instanceof List) {
                List<?> responseList = (List<?>) response;
                if (!responseList.isEmpty() && responseList.get(0) instanceof Map) {
                    return responseList.get(0);
                }
            }
            return dataMap;
        }
        return null;
    }
}

