/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.quantizationservice;

import java.io.IOException;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.Version;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.quantizationservice.KNNVectorQuantizationTrainingRequest;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.quantizer.Quantizer;

/**
 * Singleton facade that routes training, quantization and ADC transformation requests to the
 * {@link Quantizer} selected by {@link QuantizerFactory} for the given parameters/state.
 *
 * <p>The class holds no instance state, so one shared instance is reused for every {@code <T, R>}
 * combination (generic singleton factory).
 *
 * @param <T> type of the input vector handed to the quantizer
 * @param <R> type of the quantized vector produced by the quantizer
 */
public final class QuantizationService<T, R> {
    private static final QuantizationService<?, ?> INSTANCE = new QuantizationService<>();

    /**
     * Returns the shared service instance, typed for the caller's vector/output types.
     *
     * @param <T> input vector type
     * @param <R> quantized output type
     * @return the singleton instance
     */
    // Safe: the service has no fields depending on T or R, so one instance serves every parameterization.
    @SuppressWarnings("unchecked")
    public static <T, R> QuantizationService<T, R> getInstance() {
        return (QuantizationService<T, R>) INSTANCE;
    }

    /**
     * Trains a quantizer chosen from {@code quantizationParams} on the supplied vectors.
     *
     * <p>For {@link ScalarQuantizationParams}, the training request also carries the
     * random-rotation flag from the params; any other params build a request without it.
     *
     * @param quantizationParams parameters that select the quantizer and its configuration
     * @param knnVectorValuesSupplier supplier of the vector values to train on
     * @param liveDocs count passed through to the training request (presumably the number of live docs — confirm)
     * @return the quantization state produced by the quantizer's training
     * @throws IOException if training fails while reading vectors
     */
    public QuantizationState train(QuantizationParams quantizationParams, Supplier<KNNVectorValues<T>> knnVectorValuesSupplier, long liveDocs) throws IOException {
        Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams);
        KNNVectorQuantizationTrainingRequest<T> trainingRequest;
        if (quantizationParams instanceof ScalarQuantizationParams) {
            ScalarQuantizationParams scalarQuantizationParams = (ScalarQuantizationParams) quantizationParams;
            trainingRequest = new KNNVectorQuantizationTrainingRequest<T>(knnVectorValuesSupplier, liveDocs, scalarQuantizationParams.isEnableRandomRotation());
        } else {
            trainingRequest = new KNNVectorQuantizationTrainingRequest<T>(knnVectorValuesSupplier, liveDocs);
        }
        return quantizer.train(trainingRequest);
    }

    /**
     * Quantizes {@code vector} with the quantizer matching the state's parameters.
     *
     * @param quantizationState trained state whose params select the quantizer
     * @param vector the vector to quantize
     * @param quantizationOutput output holder the quantizer writes into
     * @return the quantized vector read back from {@code quantizationOutput}
     */
    public R quantize(QuantizationState quantizationState, T vector, QuantizationOutput<R> quantizationOutput) {
        Quantizer<T, R> quantizer = QuantizerFactory.getQuantizer(quantizationState.getQuantizationParams());
        quantizer.quantize(vector, quantizationState, quantizationOutput);
        return quantizationOutput.getQuantizedVector();
    }

    /**
     * Delegates an ADC (asymmetric distance computation) transformation of {@code vector} to the
     * quantizer matching the state's parameters.
     *
     * <p>NOTE(review): nothing is returned, so the quantizer presumably transforms {@code vector}
     * in place — confirm against the {@code Quantizer} implementation.
     *
     * @param quantizationState trained state whose params select the quantizer
     * @param vector the vector to transform
     * @param spaceType the space type forwarded to the quantizer
     */
    public void transformWithADC(QuantizationState quantizationState, T vector, SpaceType spaceType) {
        Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationState.getQuantizationParams());
        quantizer.transformWithADC(vector, quantizationState, spaceType);
    }

    /**
     * Builds scalar-quantization parameters from the field's quantization config.
     *
     * @param fieldInfo field whose attributes hold the quantization config
     * @param luceneVersion Lucene version used when extracting the config
     * @return {@link ScalarQuantizationParams} with the config's type, random-rotation and ADC
     *     flags, or {@code null} when the field has no quantization configured
     */
    public QuantizationParams getQuantizationParams(FieldInfo fieldInfo, Version luceneVersion) {
        QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo, luceneVersion);
        if (!isQuantizationConfigured(quantizationConfig)) {
            return null;
        }
        return ScalarQuantizationParams.builder()
            .sqType(quantizationConfig.getQuantizationType())
            .enableRandomRotation(quantizationConfig.isEnableRandomRotation())
            .enableADC(quantizationConfig.isEnableADC())
            .build();
    }

    /**
     * Returns the vector data type to use for transfer when quantization is configured.
     *
     * @param fieldInfo field whose attributes hold the quantization config
     * @param luceneVersion Lucene version used when extracting the config
     * @return {@link VectorDataType#BINARY} when the field has quantization configured, otherwise {@code null}
     */
    public VectorDataType getVectorDataTypeForTransfer(FieldInfo fieldInfo, Version luceneVersion) {
        QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo, luceneVersion);
        return isQuantizationConfigured(quantizationConfig) ? VectorDataType.BINARY : null;
    }

    /**
     * Creates an output holder suited to the given quantization parameters.
     *
     * @param quantizationParams parameters describing the quantization; may be {@code null}
     *     (e.g. the result of {@link #getQuantizationParams} for an unquantized field)
     * @return a {@link BinaryQuantizationOutput} sized from the scalar quantization type's id
     * @throws IllegalArgumentException if the params are {@code null} or not {@link ScalarQuantizationParams}
     */
    // The binary output's element type is fixed; R is only the caller's view of it.
    @SuppressWarnings("unchecked")
    public QuantizationOutput<R> createQuantizationOutput(QuantizationParams quantizationParams) {
        if (quantizationParams instanceof ScalarQuantizationParams) {
            ScalarQuantizationParams scalarParams = (ScalarQuantizationParams) quantizationParams;
            return (QuantizationOutput<R>) new BinaryQuantizationOutput(scalarParams.getSqType().getId());
        }
        // Avoid an NPE masking the intended IllegalArgumentException when params are null.
        String paramsType = quantizationParams == null ? "null" : quantizationParams.getClass().getName();
        throw new IllegalArgumentException("Unsupported quantization parameters: " + paramsType);
    }

    /** Returns true when the config is present, non-EMPTY and names a quantization type. */
    private static boolean isQuantizationConfigured(QuantizationConfig quantizationConfig) {
        // EMPTY is compared by reference, matching the sentinel usage in the original checks.
        return quantizationConfig != null
            && quantizationConfig != QuantizationConfig.EMPTY
            && quantizationConfig.getQuantizationType() != null;
    }

    @Generated
    private QuantizationService() {
    }
}

