/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.streams.processor.assignment.assignors;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.assignment.ApplicationState;
import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment;
import org.apache.kafka.streams.processor.assignment.KafkaStreamsState;
import org.apache.kafka.streams.processor.assignment.ProcessId;
import org.apache.kafka.streams.processor.assignment.TaskAssignmentUtils;
import org.apache.kafka.streams.processor.assignment.TaskAssignor;
import org.apache.kafka.streams.processor.assignment.TaskInfo;
import org.apache.kafka.streams.processor.assignment.TaskTopicPartition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class StickyTaskAssignor
implements TaskAssignor {
    private static final Logger LOG = LoggerFactory.getLogger(StickyTaskAssignor.class);
    public static final int DEFAULT_STICKY_TRAFFIC_COST = 1;
    public static final int DEFAULT_STICKY_NON_OVERLAP_COST = 10;
    private final boolean mustPreserveActiveTaskAssignment;

    public StickyTaskAssignor() {
        this(false);
    }

    public StickyTaskAssignor(boolean mustPreserveActiveTaskAssignment) {
        this.mustPreserveActiveTaskAssignment = mustPreserveActiveTaskAssignment;
    }

    @Override
    public TaskAssignor.TaskAssignment assign(ApplicationState applicationState) {
        Map<ProcessId, KafkaStreamsState> clients = applicationState.kafkaStreamsStates(false);
        Map<TaskId, ProcessId> previousActiveAssignment = StickyTaskAssignor.mapPreviousActiveTasks(clients);
        Map<TaskId, Set<ProcessId>> previousStandbyAssignment = StickyTaskAssignor.mapPreviousStandbyTasks(clients);
        AssignmentState assignmentState = new AssignmentState(applicationState, clients, previousActiveAssignment, previousStandbyAssignment);
        StickyTaskAssignor.assignActive(applicationState, clients.values(), assignmentState, this.mustPreserveActiveTaskAssignment);
        this.optimizeActive(applicationState, assignmentState);
        StickyTaskAssignor.assignStandby(applicationState, assignmentState);
        this.optimizeStandby(applicationState, assignmentState);
        Map<ProcessId, KafkaStreamsAssignment> finalAssignments = assignmentState.newAssignments;
        if (this.mustPreserveActiveTaskAssignment && !finalAssignments.isEmpty()) {
            ProcessId clientId = finalAssignments.entrySet().iterator().next().getKey();
            KafkaStreamsAssignment previousAssignment = finalAssignments.get(clientId);
            finalAssignments.put(clientId, previousAssignment.withFollowupRebalance(Instant.ofEpochMilli(0L)));
        }
        return new TaskAssignor.TaskAssignment(finalAssignments.values());
    }

    private void optimizeActive(ApplicationState applicationState, AssignmentState assignmentState) {
        if (this.mustPreserveActiveTaskAssignment) {
            return;
        }
        Map<ProcessId, KafkaStreamsAssignment> currentAssignments = assignmentState.newAssignments;
        TaskAssignmentUtils.RackAwareOptimizationParams statefulTaskParams = TaskAssignmentUtils.RackAwareOptimizationParams.of(applicationState).withTrafficCostOverride(applicationState.assignmentConfigs().rackAwareTrafficCost().orElse(1)).withNonOverlapCostOverride(applicationState.assignmentConfigs().rackAwareNonOverlapCost().orElse(10)).forStatefulTasks();
        TaskAssignmentUtils.optimizeRackAwareActiveTasks(statefulTaskParams, currentAssignments);
        TaskAssignmentUtils.optimizeRackAwareActiveTasks(TaskAssignmentUtils.RackAwareOptimizationParams.of(applicationState).forStatelessTasks().withTrafficCostOverride(1).withNonOverlapCostOverride(0), currentAssignments);
        assignmentState.processOptimizedAssignments(currentAssignments);
    }

    private void optimizeStandby(ApplicationState applicationState, AssignmentState assignmentState) {
        if (applicationState.assignmentConfigs().numStandbyReplicas() <= 0) {
            return;
        }
        if (this.mustPreserveActiveTaskAssignment) {
            return;
        }
        Map<ProcessId, KafkaStreamsAssignment> assignments = assignmentState.newAssignments;
        TaskAssignmentUtils.RackAwareOptimizationParams optimizationParams = TaskAssignmentUtils.RackAwareOptimizationParams.of(applicationState).withTrafficCostOverride(applicationState.assignmentConfigs().rackAwareTrafficCost().orElse(1)).withNonOverlapCostOverride(applicationState.assignmentConfigs().rackAwareNonOverlapCost().orElse(10));
        TaskAssignmentUtils.optimizeRackAwareStandbyTasks(optimizationParams, assignments);
        assignmentState.processOptimizedAssignments(assignments);
    }

    private static void assignActive(ApplicationState applicationState, Collection<KafkaStreamsState> clients, AssignmentState assignmentState, boolean mustPreserveActiveTaskAssignment) {
        int totalCapacity = StickyTaskAssignor.computeTotalProcessingThreads(clients);
        Set<TaskId> allTaskIds = applicationState.allTasks().keySet();
        int taskCount = allTaskIds.size();
        int activeTasksPerThread = taskCount / totalCapacity;
        HashSet<TaskId> unassigned = new HashSet<TaskId>(allTaskIds);
        for (TaskId taskId : assignmentState.previousActiveAssignment.keySet()) {
            ProcessId previousClientForTask = assignmentState.previousActiveAssignment.get(taskId);
            if (!allTaskIds.contains(taskId) || !mustPreserveActiveTaskAssignment && !assignmentState.hasRoomForActiveTask(previousClientForTask, activeTasksPerThread)) continue;
            assignmentState.finalizeAssignment(taskId, previousClientForTask, KafkaStreamsAssignment.AssignedTask.Type.ACTIVE);
            unassigned.remove(taskId);
        }
        Iterator iterator = unassigned.iterator();
        block1: while (iterator.hasNext()) {
            TaskId taskId;
            taskId = (TaskId)iterator.next();
            Set previousClientsForStandbyTask = assignmentState.previousStandbyAssignment.getOrDefault(taskId, new HashSet());
            for (ProcessId client : previousClientsForStandbyTask) {
                if (!assignmentState.hasRoomForActiveTask(client, activeTasksPerThread)) continue;
                assignmentState.finalizeAssignment(taskId, client, KafkaStreamsAssignment.AssignedTask.Type.ACTIVE);
                iterator.remove();
                continue block1;
            }
        }
        ArrayList<TaskId> sortedTasks = new ArrayList<TaskId>(unassigned);
        Collections.sort(sortedTasks);
        for (TaskId taskId : sortedTasks) {
            Set<ProcessId> candidateClients = clients.stream().map(KafkaStreamsState::processId).collect(Collectors.toSet());
            ProcessId bestClient = assignmentState.findBestClientForTask(taskId, candidateClients);
            assignmentState.finalizeAssignment(taskId, bestClient, KafkaStreamsAssignment.AssignedTask.Type.ACTIVE);
        }
    }

    private static void assignStandby(ApplicationState applicationState, AssignmentState assignmentState) {
        Set statefulTasks = applicationState.allTasks().values().stream().filter(taskInfo -> taskInfo.topicPartitions().stream().anyMatch(TaskTopicPartition::isChangelog)).collect(Collectors.toSet());
        int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas();
        block0: for (TaskInfo task : statefulTasks) {
            for (int i = 0; i < numStandbyReplicas; ++i) {
                Set<ProcessId> candidateClients = assignmentState.findClientsWithoutAssignedTask(task.id());
                if (candidateClients.isEmpty()) {
                    LOG.warn("Unable to assign {} of {} standby tasks for task [{}]. There is not enough available capacity. You should increase the number of threads and/or application instances to maintain the requested number of standby replicas.", new Object[]{numStandbyReplicas - i, numStandbyReplicas, task.id()});
                    continue block0;
                }
                ProcessId bestClient = assignmentState.findBestClientForTask(task.id(), candidateClients);
                assignmentState.finalizeAssignment(task.id(), bestClient, KafkaStreamsAssignment.AssignedTask.Type.STANDBY);
            }
        }
    }

    private static Map<TaskId, ProcessId> mapPreviousActiveTasks(Map<ProcessId, KafkaStreamsState> clients) {
        HashMap<TaskId, ProcessId> previousActiveTasks = new HashMap<TaskId, ProcessId>();
        for (KafkaStreamsState client : clients.values()) {
            for (TaskId taskId : client.previousActiveTasks()) {
                previousActiveTasks.put(taskId, client.processId());
            }
        }
        return previousActiveTasks;
    }

    private static Map<TaskId, Set<ProcessId>> mapPreviousStandbyTasks(Map<ProcessId, KafkaStreamsState> clients) {
        HashMap<TaskId, Set<ProcessId>> previousStandbyTasks = new HashMap<TaskId, Set<ProcessId>>();
        for (KafkaStreamsState client : clients.values()) {
            for (TaskId taskId : client.previousStandbyTasks()) {
                previousStandbyTasks.computeIfAbsent(taskId, k -> new HashSet());
                ((Set)previousStandbyTasks.get(taskId)).add(client.processId());
            }
        }
        return previousStandbyTasks;
    }

    private static int computeTotalProcessingThreads(Collection<KafkaStreamsState> clients) {
        int count = 0;
        for (KafkaStreamsState client : clients) {
            count += client.numProcessingThreads();
        }
        return count;
    }

    private static class AssignmentState {
        private final Map<ProcessId, KafkaStreamsState> clients;
        private final Map<TaskId, ProcessId> previousActiveAssignment;
        private final Map<TaskId, Set<ProcessId>> previousStandbyAssignment;
        private final TaskPairs taskPairs;
        private Map<TaskId, Set<ProcessId>> newTaskLocations;
        private Map<ProcessId, KafkaStreamsAssignment> newAssignments;

        private AssignmentState(ApplicationState applicationState, Map<ProcessId, KafkaStreamsState> clients, Map<TaskId, ProcessId> previousActiveAssignment, Map<TaskId, Set<ProcessId>> previousStandbyAssignment) {
            this.clients = clients;
            this.previousActiveAssignment = Collections.unmodifiableMap(previousActiveAssignment);
            this.previousStandbyAssignment = Collections.unmodifiableMap(previousStandbyAssignment);
            int taskCount = applicationState.allTasks().size();
            int maxPairs = taskCount * (taskCount - 1) / 2;
            this.taskPairs = new TaskPairs(maxPairs);
            this.newTaskLocations = previousActiveAssignment.keySet().stream().collect(Collectors.toMap(Function.identity(), taskId -> new HashSet()));
            this.newAssignments = clients.values().stream().collect(Collectors.toMap(KafkaStreamsState::processId, state -> KafkaStreamsAssignment.of(state.processId(), new HashSet<KafkaStreamsAssignment.AssignedTask>())));
        }

        private void finalizeAssignment(TaskId taskId, ProcessId client, KafkaStreamsAssignment.AssignedTask.Type type) {
            Set<TaskId> newAssignmentsForClient = this.newAssignments.get(client).tasks().keySet();
            this.taskPairs.addPairs(taskId, newAssignmentsForClient);
            this.newAssignments.get(client).assignTask(new KafkaStreamsAssignment.AssignedTask(taskId, type));
            this.newTaskLocations.computeIfAbsent(taskId, k -> new HashSet()).add(client);
        }

        private void processOptimizedAssignments(Map<ProcessId, KafkaStreamsAssignment> optimizedAssignments) {
            HashMap<TaskId, Set<ProcessId>> newTaskLocations = new HashMap<TaskId, Set<ProcessId>>();
            for (Map.Entry<ProcessId, KafkaStreamsAssignment> entry : optimizedAssignments.entrySet()) {
                ProcessId processId = entry.getKey();
                HashSet<KafkaStreamsAssignment.AssignedTask> assignedTasks = new HashSet<KafkaStreamsAssignment.AssignedTask>(optimizedAssignments.get(processId).tasks().values());
                for (KafkaStreamsAssignment.AssignedTask task : assignedTasks) {
                    newTaskLocations.computeIfAbsent(task.id(), k -> new HashSet()).add(processId);
                }
            }
            this.newTaskLocations = newTaskLocations;
            this.newAssignments = optimizedAssignments;
        }

        private boolean hasRoomForActiveTask(ProcessId processId, int activeTasksPerThread) {
            int capacity = this.clients.get(processId).numProcessingThreads();
            int newActiveTaskCount = this.newAssignments.computeIfAbsent(processId, k -> KafkaStreamsAssignment.of(processId, new HashSet<KafkaStreamsAssignment.AssignedTask>())).tasks().values().stream().filter(assignedTask -> assignedTask.type() == KafkaStreamsAssignment.AssignedTask.Type.ACTIVE).collect(Collectors.toSet()).size();
            return newActiveTaskCount < capacity * activeTasksPerThread;
        }

        private ProcessId findBestClientForTask(TaskId taskId, Set<ProcessId> clientsWithin) {
            if (clientsWithin.size() == 1) {
                return clientsWithin.iterator().next();
            }
            ProcessId previousClient = this.findLeastLoadedClientWithPreviousActiveOrStandbyTask(taskId, clientsWithin);
            if (previousClient == null) {
                return this.findLeastLoadedClient(taskId, clientsWithin);
            }
            if (this.shouldBalanceLoad(previousClient)) {
                ProcessId standby = this.findLeastLoadedClientWithPreviousStandbyTask(taskId, clientsWithin);
                if (standby == null || this.shouldBalanceLoad(standby)) {
                    return this.findLeastLoadedClient(taskId, clientsWithin);
                }
                return standby;
            }
            return previousClient;
        }

        private Set<ProcessId> findClientsWithoutAssignedTask(TaskId taskId) {
            Set<ProcessId> unavailableClients = this.newTaskLocations.get(taskId);
            return this.clients.values().stream().map(KafkaStreamsState::processId).filter(o -> !unavailableClients.contains(o)).collect(Collectors.toSet());
        }

        private double clientLoad(ProcessId processId) {
            int capacity = this.clients.get(processId).numProcessingThreads();
            double totalTaskCount = this.newAssignments.get(processId).tasks().size();
            return totalTaskCount / (double)capacity;
        }

        private ProcessId findLeastLoadedClient(TaskId taskId, Set<ProcessId> clientIds) {
            double thisClientLoad;
            ProcessId leastLoaded = null;
            for (ProcessId processId : clientIds) {
                Set<TaskId> assignedTasks;
                thisClientLoad = this.clientLoad(processId);
                if (thisClientLoad == 0.0) {
                    return processId;
                }
                if (leastLoaded != null && !(thisClientLoad < this.clientLoad(leastLoaded)) || !this.taskPairs.hasNewPair(taskId, assignedTasks = this.newAssignments.get(processId).tasks().values().stream().map(KafkaStreamsAssignment.AssignedTask::id).collect(Collectors.toSet()))) continue;
                leastLoaded = processId;
            }
            if (leastLoaded != null) {
                return leastLoaded;
            }
            for (ProcessId processId : clientIds) {
                thisClientLoad = this.clientLoad(processId);
                if (leastLoaded != null && !(thisClientLoad < this.clientLoad(leastLoaded))) continue;
                leastLoaded = processId;
            }
            return leastLoaded;
        }

        private ProcessId findLeastLoadedClientWithPreviousActiveOrStandbyTask(TaskId taskId, Set<ProcessId> clientsWithin) {
            ProcessId previous = this.previousActiveAssignment.get(taskId);
            if (previous != null && clientsWithin.contains(previous)) {
                return previous;
            }
            return this.findLeastLoadedClientWithPreviousStandbyTask(taskId, clientsWithin);
        }

        private ProcessId findLeastLoadedClientWithPreviousStandbyTask(TaskId taskId, Set<ProcessId> clientsWithin) {
            Set ids = this.previousStandbyAssignment.getOrDefault(taskId, new HashSet());
            HashSet<ProcessId> constrainTo = new HashSet<ProcessId>(ids);
            constrainTo.retainAll(clientsWithin);
            return this.findLeastLoadedClient(taskId, constrainTo);
        }

        private boolean shouldBalanceLoad(ProcessId client) {
            double thisClientLoad = this.clientLoad(client);
            if (thisClientLoad < 1.0) {
                return false;
            }
            for (ProcessId otherClient : this.clients.keySet()) {
                if (!(this.clientLoad(otherClient) < thisClientLoad)) continue;
                return true;
            }
            return false;
        }
    }

    private static class TaskPair {
        private final TaskId task1;
        private final TaskId task2;

        TaskPair(TaskId task1, TaskId task2) {
            this.task1 = task1;
            this.task2 = task2;
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            TaskPair pair = (TaskPair)o;
            return Objects.equals(this.task1, pair.task1) && Objects.equals(this.task2, pair.task2);
        }

        public int hashCode() {
            return Objects.hash(this.task1, this.task2);
        }
    }

    private static class TaskPairs {
        private final Set<TaskPair> pairs;
        private final int maxPairs;

        public TaskPairs(int maxPairs) {
            this.maxPairs = maxPairs;
            this.pairs = new HashSet<TaskPair>(maxPairs);
        }

        public boolean hasNewPair(TaskId task1, Set<TaskId> taskIds) {
            if (this.pairs.size() == this.maxPairs) {
                return false;
            }
            for (TaskId taskId : taskIds) {
                if (this.pairs.contains(this.pair(task1, taskId))) continue;
                return true;
            }
            return false;
        }

        public void addPairs(TaskId taskId, Set<TaskId> assigned) {
            for (TaskId id : assigned) {
                this.pairs.add(this.pair(id, taskId));
            }
        }

        public TaskPair pair(TaskId task1, TaskId task2) {
            if (task1.compareTo(task2) < 0) {
                return new TaskPair(task1, task2);
            }
            return new TaskPair(task2, task1);
        }
    }
}

