/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.msq.kernel.controller;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import javax.annotation.Nullable;
import org.apache.druid.msq.exec.OutputChannelMode;
import org.apache.druid.msq.indexing.destination.MSQDestination;
import org.apache.druid.msq.indexing.destination.MSQSelectDestination;
import org.apache.druid.msq.input.InputSpec;
import org.apache.druid.msq.input.InputSpecs;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig;
import org.apache.druid.msq.kernel.controller.StageGroup;

/**
 * Static helpers used by the controller query kernel to plan execution of a {@link QueryDefinition}: building stage
 * dependency maps, deciding which stages may write their output to memory, and grouping stages into
 * {@link StageGroup}s.
 */
public class ControllerQueryKernelUtils {
    /**
     * Partitions the stages of {@code queryDef} into an ordered list of {@link StageGroup}s by simulating a run of the
     * query. A stage is "run" when it is added to a group. Each round of the simulation:
     * <ol>
     *   <li>repeatedly runs every ready stage (all inputs already run) that cannot use memory output, each as its own
     *   single-stage group with a non-memory output channel; then</li>
     *   <li>picks the first ready stage that can use memory output and builds a group starting from it, extending the
     *   group to the stage's single consumer while the current stage's output mode is {@link OutputChannelMode#MEMORY},
     *   it does not sort during shuffle, and the group is below its size budget.</li>
     * </ol>
     *
     * @param queryDef query whose stages are grouped
     * @param config   controller configuration (destination, durable storage, pipelining, max concurrent stages)
     * @return stage groups in the order the simulation ran them
     * @throws IllegalStateException if some stages can never become ready, for example because the stage graph
     *                               contains a cycle or a stage reads from a stage not defined in {@code queryDef}
     */
    public static List<StageGroup> computeStageGroups(QueryDefinition queryDef, ControllerQueryKernelConfig config) {
        final MSQDestination destination = config.getDestination();
        final boolean useDurableStorage = config.isDurableStorage();
        final List<StageGroup> stageGroups = new ArrayList<>();
        final Map<StageId, SortedSet<StageId>> inflow = computeStageInflowMap(queryDef);
        final Map<StageId, SortedSet<StageId>> outflow = computeStageOutflowMap(queryDef);
        final Set<StageId> stagesRun = new HashSet<>();
        final int stageCount = queryDef.getStageDefinitions().size();

        while (stagesRun.size() < stageCount) {
            // 1) Run every ready stage that cannot write to memory, one stage per group.
            runStagesWithoutMemoryOutput(queryDef, config, inflow, outflow, stagesRun, stageGroups);

            // 2) Start a (possibly multi-stage) group at the first ready stage that can write to memory.
            StageId currentStageId = findReadyMemoryOutputStage(queryDef, config, inflow, outflow, stagesRun);
            if (currentStageId == null) {
                if (stagesRun.size() < stageCount) {
                    // Step 1 has reached a fixpoint and no memory-capable stage is ready, so another round cannot make
                    // progress. Fail loudly instead of looping forever.
                    throw new IllegalStateException(
                            "Cannot schedule remaining stages of query[" + queryDef.getQueryId() + "]: "
                                    + stagesRun.size() + " of " + stageCount
                                    + " stages were scheduled, but no other stage has all of its inputs available");
                }
                break;
            }

            final int maxStageGroupSize = computeMaxStageGroupSize(stageGroups, config.getMaxConcurrentStages());
            final List<StageId> currentStageGroup = new ArrayList<>();
            // The group's mode ends up being the mode of the last stage added to it.
            OutputChannelMode currentOutputChannelMode = null;

            while (currentStageId != null) {
                final boolean sortsDuringShuffle = queryDef.getStageDefinition(currentStageId).doesSortDuringShuffle();
                final boolean stageCanUseMemoryOutput =
                        canUseMemoryOutput(queryDef, currentStageId.getStageNumber(), config, outflow);
                final Set<StageId> currentOutflow = outflow.get(currentStageId);

                // Largest group size (before adding this stage) that still leaves a slot for this stage's consumer.
                // NOTE(review): a sorting stage is budgeted against the full concurrency limit rather than
                // maxStageGroupSize; presumably its consumer need not overlap the prior group — confirm.
                final int maxGroupSizeAllowingForConsumer = sortsDuringShuffle
                        ? config.getMaxConcurrentStages() - 1
                        : maxStageGroupSize - 1;

                // When stageCanUseMemoryOutput is true, the outflow has at most one stage, so getOnlyElement is safe
                // in the short-circuited clauses below.
                final boolean canStream = stageCanUseMemoryOutput
                        // A stage with no consumer can always write to memory; otherwise its consumer must have no
                        // unfinished input other than this stage...
                        && (currentOutflow.isEmpty()
                                || Collections.singleton(currentStageId)
                                        .equals(inflow.get(Iterables.getOnlyElement(currentOutflow))))
                        // ...and the group must have room for that consumer.
                        && (currentOutflow.isEmpty() || currentStageGroup.size() < maxGroupSizeAllowingForConsumer);

                currentOutputChannelMode = getOutputChannelMode(
                        queryDef,
                        currentStageId.getStageNumber(),
                        destination.toSelectDestination(),
                        useDurableStorage,
                        canStream
                );
                currentStageGroup.add(currentStageId);

                // Sorting during shuffle ends the group even when the output goes to memory.
                final boolean extendGroup = currentOutflow.size() == 1
                        && currentStageGroup.size() < maxStageGroupSize
                        && currentOutputChannelMode == OutputChannelMode.MEMORY
                        && !sortsDuringShuffle;
                currentStageId = extendGroup ? Iterables.getOnlyElement(currentOutflow) : null;
            }

            stageGroups.add(new StageGroup(currentStageGroup, currentOutputChannelMode));
            for (final StageId groupStageId : currentStageGroup) {
                stagesRun.add(groupStageId);
                removeStageFlow(groupStageId, inflow, outflow);
            }
        }
        return stageGroups;
    }

    /**
     * Returns whether {@code stageId} has not run yet and every stage it reads from has already run (its inflow set is
     * empty, because {@link #removeStageFlow} drops a stage from its consumers' inflow sets once it runs).
     */
    private static boolean isReadyToRun(
            StageId stageId,
            Map<StageId, SortedSet<StageId>> inflow,
            Set<StageId> stagesRun
    ) {
        return !stagesRun.contains(stageId) && inflow.get(stageId).isEmpty();
    }

    /**
     * Step 1 of {@link #computeStageGroups}. Repeatedly runs, until a full pass runs nothing, every ready stage
     * that cannot use memory output. Each such stage becomes a single-stage group with a non-memory output mode.
     * Mutates {@code inflow}, {@code outflow}, {@code stagesRun} and {@code stageGroups}.
     */
    private static void runStagesWithoutMemoryOutput(
            QueryDefinition queryDef,
            ControllerQueryKernelConfig config,
            Map<StageId, SortedSet<StageId>> inflow,
            Map<StageId, SortedSet<StageId>> outflow,
            Set<StageId> stagesRun,
            List<StageGroup> stageGroups
    ) {
        boolean didRun;
        do {
            didRun = false;
            for (final StageId stageId : ImmutableList.copyOf(inflow.keySet())) {
                if (isReadyToRun(stageId, inflow, stagesRun)
                        && !canUseMemoryOutput(queryDef, stageId.getStageNumber(), config, outflow)) {
                    stagesRun.add(stageId);
                    stageGroups.add(new StageGroup(
                            Collections.singletonList(stageId),
                            getOutputChannelMode(
                                    queryDef,
                                    stageId.getStageNumber(),
                                    config.getDestination().toSelectDestination(),
                                    config.isDurableStorage(),
                                    false
                            )
                    ));
                    removeStageFlow(stageId, inflow, outflow);
                    didRun = true;
                }
            }
        } while (didRun);
    }

    /**
     * Returns the first stage, in {@code inflow} key order, that is ready to run and can use memory output, or
     * {@code null} if there is none. Does not modify any of its arguments.
     */
    @Nullable
    private static StageId findReadyMemoryOutputStage(
            QueryDefinition queryDef,
            ControllerQueryKernelConfig config,
            Map<StageId, SortedSet<StageId>> inflow,
            Map<StageId, SortedSet<StageId>> outflow,
            Set<StageId> stagesRun
    ) {
        // Nothing is mutated during this scan, so iterating the live key set is safe.
        for (final StageId stageId : inflow.keySet()) {
            if (isReadyToRun(stageId, inflow, stagesRun)
                    && canUseMemoryOutput(queryDef, stageId.getStageNumber(), config, outflow)) {
                return stageId;
            }
        }
        return null;
    }

    /**
     * Returns how many stages the next group may contain. If the most recently added group ends in
     * {@link OutputChannelMode#MEMORY}, its stages are subtracted from {@code maxConcurrentStages}, on the assumption
     * that they still occupy slots while the new group runs.
     */
    private static int computeMaxStageGroupSize(List<StageGroup> stageGroups, int maxConcurrentStages) {
        if (!stageGroups.isEmpty()) {
            final StageGroup priorGroup = stageGroups.get(stageGroups.size() - 1);
            if (priorGroup.lastStageOutputChannelMode() == OutputChannelMode.MEMORY) {
                return maxConcurrentStages - priorGroup.size();
            }
        }
        return maxConcurrentStages;
    }

    /**
     * Builds a map from every stage of {@code queryDefinition} to the set of stages it reads from.
     *
     * @param queryDefinition query whose stage graph is inspected
     * @return a sorted, mutable map with one entry per defined stage; values are mutable sorted sets
     */
    public static Map<StageId, SortedSet<StageId>> computeStageInflowMap(QueryDefinition queryDefinition) {
        final Map<StageId, SortedSet<StageId>> retVal = new TreeMap<>();
        for (final StageDefinition stageDef : queryDefinition.getStageDefinitions()) {
            final StageId stageId = stageDef.getId();
            final SortedSet<StageId> inputStageIds = retVal.computeIfAbsent(stageId, ignored -> new TreeSet<>());
            final IntIterator inputStageNumbers = stageDef.getInputStageNumbers().iterator();
            while (inputStageNumbers.hasNext()) {
                inputStageIds.add(new StageId(queryDefinition.getQueryId(), inputStageNumbers.nextInt()));
            }
        }
        return retVal;
    }

    /**
     * Builds a map from each stage to the set of stages that read from it. Every defined stage has an entry (possibly
     * empty). A stage referenced only as an input also gets an entry.
     *
     * @param queryDefinition query whose stage graph is inspected
     * @return a sorted, mutable map; values are mutable sorted sets
     */
    public static Map<StageId, SortedSet<StageId>> computeStageOutflowMap(QueryDefinition queryDefinition) {
        final Map<StageId, SortedSet<StageId>> retVal = new TreeMap<>();
        for (final StageDefinition stageDef : queryDefinition.getStageDefinitions()) {
            final StageId stageId = stageDef.getId();
            retVal.computeIfAbsent(stageId, ignored -> new TreeSet<>());
            final IntIterator inputStageNumbers = stageDef.getInputStageNumbers().iterator();
            while (inputStageNumbers.hasNext()) {
                final StageId inputStageId = new StageId(queryDefinition.getQueryId(), inputStageNumbers.nextInt());
                retVal.computeIfAbsent(inputStageId, ignored -> new TreeSet<>()).add(stageId);
            }
        }
        return retVal;
    }

    /**
     * Returns whether stage {@code stageNumber} may write its output to memory, given the current outflow map.
     * Always {@code false} when the query is fault tolerant, pipelining is disabled, or fewer than two stages may run
     * concurrently. Otherwise the result is:
     * <ul>
     *   <li>{@code true} if the stage has no consumers;</li>
     *   <li>{@code true} if it has exactly one consumer whose only non-broadcast input is this stage;</li>
     *   <li>{@code false} in every other case.</li>
     * </ul>
     *
     * @param outflowMap map from each stage to its consuming stages (see {@link #computeStageOutflowMap}); must
     *                   contain an entry for this stage
     */
    public static boolean canUseMemoryOutput(
            QueryDefinition queryDefinition,
            int stageNumber,
            ControllerQueryKernelConfig config,
            Map<StageId, SortedSet<StageId>> outflowMap
    ) {
        if (config.isFaultTolerant() || !config.isPipeline() || config.getMaxConcurrentStages() < 2) {
            return false;
        }
        final StageId stageId = queryDefinition.getStageDefinition(stageNumber).getId();
        final Set<StageId> outflowStageIds = outflowMap.get(stageId);
        if (outflowStageIds.isEmpty()) {
            return true;
        }
        if (outflowStageIds.size() == 1) {
            final StageDefinition consumerStageDef =
                    queryDefinition.getStageDefinition(Iterables.getOnlyElement(outflowStageIds));
            return stageId.equals(getOnlyNonBroadcastInputAsStageId(consumerStageDef));
        }
        return false;
    }

    /**
     * Chooses the output channel mode for a stage. Writing the final stage's results to durable storage (when the
     * select destination is {@link MSQSelectDestination#DURABLESTORAGE}) takes precedence. Otherwise the order is
     * memory when {@code canStream}, then durable intermediate storage when {@code durableStorage}, then local
     * storage.
     */
    public static OutputChannelMode getOutputChannelMode(
            QueryDefinition queryDef,
            int stageNumber,
            @Nullable MSQSelectDestination selectDestination,
            boolean durableStorage,
            boolean canStream
    ) {
        final boolean isFinalStage = queryDef.getFinalStageDefinition().getStageNumber() == stageNumber;
        if (isFinalStage && selectDestination == MSQSelectDestination.DURABLESTORAGE) {
            return OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS;
        }
        if (canStream) {
            return OutputChannelMode.MEMORY;
        }
        return durableStorage ? OutputChannelMode.DURABLE_STORAGE_INTERMEDIATE : OutputChannelMode.LOCAL_STORAGE;
    }

    /**
     * Returns the stage read by the only non-broadcast input of {@code downstreamStageDef}, or {@code null} in either
     * of these cases:
     * <ul>
     *   <li>the number of inputs minus the number of broadcast inputs is not exactly one;</li>
     *   <li>no non-broadcast input reads from exactly one stage.</li>
     * </ul>
     */
    @Nullable
    public static StageId getOnlyNonBroadcastInputAsStageId(StageDefinition downstreamStageDef) {
        final List<InputSpec> inputSpecs = downstreamStageDef.getInputSpecs();
        final IntSet broadcastInputNumbers = downstreamStageDef.getBroadcastInputNumbers();
        if (inputSpecs.size() - broadcastInputNumbers.size() != 1) {
            return null;
        }
        for (int i = 0; i < inputSpecs.size(); i++) {
            if (broadcastInputNumbers.contains(i)) {
                continue;
            }
            final IntSet stageNumbers = InputSpecs.getStageNumbers(Collections.singletonList(inputSpecs.get(i)));
            if (stageNumbers.size() == 1) {
                return new StageId(downstreamStageDef.getId().getQueryId(), stageNumbers.iterator().nextInt());
            }
        }
        return null;
    }

    /**
     * Marks {@code stageId} as run in the flow maps. It is removed from the inflow set of each of its consumers, and
     * its own outflow set is cleared.
     */
    private static void removeStageFlow(
            StageId stageId,
            Map<StageId, SortedSet<StageId>> inflow,
            Map<StageId, SortedSet<StageId>> outflow
    ) {
        final SortedSet<StageId> consumers = outflow.get(stageId);
        for (final StageId consumerStageId : consumers) {
            inflow.get(consumerStageId).remove(stageId);
        }
        consumers.clear();
    }
}

