/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.udf;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager;
import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher;
import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.iotdb.udf.api.UDTF;
import org.apache.iotdb.udf.api.access.Row;
import org.apache.iotdb.udf.api.collector.PointCollector;
import org.apache.iotdb.udf.api.customizer.config.UDTFConfigurations;
import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters;
import org.apache.iotdb.udf.api.customizer.strategy.AccessStrategy;
import org.apache.iotdb.udf.api.customizer.strategy.RowByRowAccessStrategy;
import org.apache.iotdb.udf.api.type.Type;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.TsBlock;
import org.apache.tsfile.read.common.block.TsBlockBuilder;
import org.apache.tsfile.read.common.block.column.TsBlockSerde;

/**
 * Forecasting UDTF backed by an AINode model.
 *
 * <p>Rows are consumed one at a time ({@link RowByRowAccessStrategy}) and buffered. When the query
 * finishes, {@link #terminate(PointCollector)} packs the buffered rows into a {@link TsBlock}
 * whose value columns are all DOUBLE, sends it to the AINode hosting the model named by {@code
 * MODEL_ID}, and emits the returned single-column forecast as DOUBLE points.
 *
 * <p>Recognised parameters:
 *
 * <ul>
 *   <li>{@code MODEL_ID} - required, non-empty;
 *   <li>{@code OUTPUT_LENGTH} - number of forecast points, default 96, must be positive;
 *   <li>{@code OUTPUT_START_TIME} - timestamp of the first forecast point; by default one interval
 *       after the last input timestamp;
 *   <li>{@code OUTPUT_INTERVAL} - spacing between forecast points; when not positive it is
 *       inferred as the (integer) average spacing of the input timestamps;
 *   <li>{@code PRESERVE_INPUT} - when true the input values are emitted as well, default false;
 *   <li>{@code MODEL_OPTIONS} - comma-separated {@code key=value} pairs forwarded to the model;
 *       entries that are not exactly one {@code key=value} pair are ignored, later keys win.
 * </ul>
 */
public class UDTFForecast implements UDTF {
    private static final TsBlockSerde serde = new TsBlockSerde();
    private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance();

    private static final Set<Type> ALLOWED_INPUT_TYPES =
            new HashSet<Type>(Arrays.asList(Type.INT32, Type.INT64, Type.FLOAT, Type.DOUBLE));
    private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
    private static final String OUTPUT_LENGTH_PARAMETER_NAME = "OUTPUT_LENGTH";
    private static final int DEFAULT_OUTPUT_LENGTH = 96;
    private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME";
    public static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE;
    private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL";
    public static final long DEFAULT_OUTPUT_INTERVAL = 0L;
    private static final String KEEP_INPUT_PARAMETER_NAME = "PRESERVE_INPUT";
    private static final boolean DEFAULT_KEEP_INPUT = false;
    private static final String OPTIONS_PARAMETER_NAME = "MODEL_OPTIONS";
    private static final String DEFAULT_OPTIONS = "";

    private TEndPoint targetAINode = new TEndPoint("127.0.0.1", 10810);
    private String modelId;
    // Never assigned in this class, so it stays 0 and the input buffer is unbounded.
    private int maxInputLength;
    private int outputLength;
    private long outputStartTime;
    private long outputInterval;
    private boolean keepInput;
    Map<String, String> options;
    List<Type> types;
    // NOTE(review): Row objects are buffered as-is; this assumes the framework does not reuse or
    // mutate a Row after transform() returns — confirm against the row-by-row executor.
    private LinkedList<Row> inputRows;
    private TsBlockBuilder inputTsBlockBuilder;
    private final IModelFetcher modelFetcher = ModelFetcher.getInstance();

    /** Rejects any input column whose type is not in {@link #ALLOWED_INPUT_TYPES}. */
    private void checkType() {
        for (Type type : this.types) {
            if (!ALLOWED_INPUT_TYPES.contains(type)) {
                throw new IllegalArgumentException(String.format("Input data type %s is not supported, only %s are allowed.", type, ALLOWED_INPUT_TYPES));
            }
        }
    }

    /**
     * Validates the input types, reads all UDF parameters, resolves the AINode that serves the
     * model and prepares the all-DOUBLE input block builder.
     *
     * @throws IllegalArgumentException if an input type is unsupported, {@code MODEL_ID} is missing
     *     or empty, or {@code OUTPUT_LENGTH} is not positive
     */
    public void beforeStart(UDFParameters parameters, UDTFConfigurations configurations) throws Exception {
        this.types = parameters.getDataTypes();
        this.checkType();
        configurations.setAccessStrategy(new RowByRowAccessStrategy()).setOutputDataType(Type.DOUBLE);

        this.modelId = parameters.getString(MODEL_ID_PARAMETER_NAME);
        if (this.modelId == null || this.modelId.isEmpty()) {
            throw new IllegalArgumentException("MODEL_ID parameter must be provided and cannot be empty.");
        }
        ModelInferenceDescriptor descriptor = this.modelFetcher.fetchModel(this.modelId);
        this.targetAINode = descriptor.getTargetAINode();

        this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL);
        this.outputLength = parameters.getIntOrDefault(OUTPUT_LENGTH_PARAMETER_NAME, DEFAULT_OUTPUT_LENGTH);
        if (this.outputLength <= 0) {
            // A non-positive length would otherwise fail later with NegativeArraySizeException.
            throw new IllegalArgumentException(String.format("%s must be positive, but got %d.", OUTPUT_LENGTH_PARAMETER_NAME, this.outputLength));
        }
        this.outputStartTime = parameters.getLongOrDefault(OUTPUT_START_TIME, DEFAULT_OUTPUT_START_TIME);
        this.keepInput = parameters.getBooleanOrDefault(KEEP_INPUT_PARAMETER_NAME, DEFAULT_KEEP_INPUT);
        this.options = parseOptions(parameters.getStringOrDefault(OPTIONS_PARAMETER_NAME, DEFAULT_OPTIONS));

        this.inputRows = new LinkedList<Row>();
        // Every value column is sent to the model as DOUBLE, whatever the source type.
        List<TSDataType> columnTypes = new ArrayList<TSDataType>(this.types.size());
        for (int i = 0; i < this.types.size(); ++i) {
            columnTypes.add(TSDataType.DOUBLE);
        }
        this.inputTsBlockBuilder = new TsBlockBuilder(columnTypes);
    }

    /**
     * Parses {@code key=value} pairs separated by commas. Keys and values are trimmed; entries that
     * do not split into exactly two parts, or whose key is blank, are dropped; duplicates keep the
     * last value.
     */
    private static Map<String, String> parseOptions(String raw) {
        return Arrays.stream(raw.split(","))
                .map(entry -> entry.split("="))
                .filter(kv -> kv.length == 2 && !kv[0].trim().isEmpty())
                .collect(Collectors.toMap(kv -> kv[0].trim(), kv -> kv[1].trim(), (first, second) -> second));
    }

    /**
     * Reads column {@code column} of {@code row} and widens it to {@code double}. INT64 values
     * beyond 2^53 lose precision, which is inherent to the DOUBLE input/output schema.
     */
    private double readAsDouble(Row row, int column) throws IOException {
        Type type = this.types.get(column);
        switch (type) {
            case INT32:
                return row.getInt(column);
            case INT64:
                return row.getLong(column);
            case FLOAT:
                return row.getFloat(column);
            case DOUBLE:
                return row.getDouble(column);
            default:
                throw new IllegalArgumentException(String.format("Unsupported data type %s", type));
        }
    }

    /**
     * Emits every non-null value of {@code row} at the row's timestamp as a DOUBLE, matching the
     * output type declared in {@link #beforeStart}.
     */
    private void setByType(Row row, PointCollector collector) throws IOException {
        long time = row.getTime();
        for (int i = 0; i < row.size(); ++i) {
            if (row.isNull(i)) {
                continue;
            }
            // NOTE(review): with several input columns this emits several points at the same
            // timestamp — confirm that the collector accepts that.
            collector.putDouble(time, this.readAsDouble(row, i));
        }
    }

    /**
     * Appends the values of {@code row} to the DOUBLE value columns of {@code tsBlockBuilder},
     * writing nulls for null cells. The caller writes the time column and declares the position.
     */
    private void setByType(Row row, TsBlockBuilder tsBlockBuilder) throws IOException {
        for (int i = 0; i < row.size(); ++i) {
            if (row.isNull(i)) {
                tsBlockBuilder.getColumnBuilder(i).appendNull();
            } else {
                tsBlockBuilder.getColumnBuilder(i).writeDouble(this.readAsDouble(row, i));
            }
        }
    }

    /**
     * Optionally re-emits the input row, then buffers it for the forecast made in
     * {@link #terminate(PointCollector)}.
     */
    public void transform(Row row, PointCollector collector) throws Exception {
        if (this.keepInput) {
            this.setByType(row, collector);
        }
        if (this.maxInputLength != 0 && this.inputRows.size() >= this.maxInputLength) {
            this.inputRows.removeFirst();
        }
        this.inputRows.add(row);
    }

    /**
     * Drains the buffered rows into a TsBlock, asks the AINode for {@link #outputLength} forecast
     * points and deserializes the response.
     *
     * @throws IoTDBRuntimeException with {@code CAN_NOT_CONNECT_AINODE} if the RPC fails (the
     *     original exception is attached as the cause), or with the returned status code if the
     *     AINode reports a failure
     */
    private TsBlock forecast() throws Exception {
        while (!this.inputRows.isEmpty()) {
            Row row = this.inputRows.removeFirst();
            this.inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getTime());
            this.setByType(row, this.inputTsBlockBuilder);
            this.inputTsBlockBuilder.declarePosition();
        }
        TsBlock inputData = this.inputTsBlockBuilder.build();

        TForecastResp resp;
        try (AINodeClient client = CLIENT_MANAGER.borrowClient(this.targetAINode)) {
            resp = client.forecast(this.modelId, inputData, this.outputLength, this.options);
        } catch (Exception e) {
            IoTDBRuntimeException wrapped = new IoTDBRuntimeException(e.getMessage(), TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode());
            // Keep the underlying failure for diagnosis instead of only its message.
            wrapped.initCause(e);
            throw wrapped;
        }
        if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
            throw new IoTDBRuntimeException(String.format("Forecast failed due to %d %s", resp.getStatus().getCode(), resp.getStatus().getMessage()), resp.getStatus().getCode());
        }
        return serde.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
    }

    /**
     * Returns the spacing between forecast points: the explicit {@code OUTPUT_INTERVAL} when
     * positive, otherwise the integer average spacing of the buffered input timestamps.
     *
     * @throws IllegalArgumentException if the interval must be inferred from fewer than two points
     */
    private long resolveOutputInterval(long inputStartTime, long inputEndTime) {
        if (this.outputInterval > 0L) {
            return this.outputInterval;
        }
        int rowCount = this.inputRows.size();
        if (rowCount < 2) {
            throw new IllegalArgumentException(String.format("Cannot infer %s from a single input point, please specify it explicitly.", OUTPUT_INTERVAL));
        }
        return (inputEndTime - inputStartTime) / (long) (rowCount - 1);
    }

    /**
     * Computes the forecast timestamps, runs the forecast and emits one DOUBLE point per forecast
     * value. Emits nothing when no rows were received.
     *
     * @throws IllegalArgumentException if the input timestamps are out of order, the interval
     *     cannot be inferred, or the model returns a result of the wrong shape
     */
    public void terminate(PointCollector collector) throws Exception {
        if (this.inputRows.isEmpty()) {
            // No input series, so there is nothing to forecast from.
            return;
        }
        long inputStartTime = this.inputRows.getFirst().getTime();
        long inputEndTime = this.inputRows.getLast().getTime();
        if (inputStartTime > inputEndTime) {
            throw new IllegalArgumentException(String.format("input end time should never less than start time, start time is %s, end time is %s", inputStartTime, inputEndTime));
        }
        long interval = this.resolveOutputInterval(inputStartTime, inputEndTime);
        long firstOutputTime = this.outputStartTime == DEFAULT_OUTPUT_START_TIME ? inputEndTime + interval : this.outputStartTime;
        long[] outputTimes = new long[this.outputLength];
        for (int i = 0; i < this.outputLength; ++i) {
            outputTimes[i] = firstOutputTime + interval * (long) i;
        }

        // forecast() drains inputRows, so all timing information is read before this call.
        TsBlock forecastResult = this.forecast();
        if (forecastResult.getPositionCount() != this.outputLength) {
            throw new IllegalArgumentException(String.format("The forecast result length %d does not match the expected output length %d", forecastResult.getPositionCount(), this.outputLength));
        }
        if (forecastResult.getValueColumnCount() != 1) {
            throw new IllegalArgumentException(String.format("The forecast result should have only one value column, but got %d", forecastResult.getValueColumnCount()));
        }
        for (int i = 0; i < forecastResult.getPositionCount(); ++i) {
            collector.putDouble(outputTimes[i], forecastResult.getValueColumns()[0].getDouble(i));
        }
    }
}

