Automatically balance leaders on failover in LeaderElector state machine.

Change-Id: I5c8cb5bd4b2cafcceb11fd3a09693d6dd34566d4
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/resources/impl/AtomixLeaderElectorService.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/resources/impl/AtomixLeaderElectorService.java
index b259cd1..de97575 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/resources/impl/AtomixLeaderElectorService.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/resources/impl/AtomixLeaderElectorService.java
@@ -30,6 +30,7 @@
 import com.google.common.base.MoreObjects;
 import com.google.common.base.Objects;
 import com.google.common.base.Throwables;
+import com.google.common.collect.ComparisonChain;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -99,6 +100,7 @@
         }
         termCounters = reader.readObject(SERIALIZER::decode);
         elections = reader.readObject(SERIALIZER::decode);
+        elections.values().forEach(e -> e.elections = elections);
         logger().debug("Reinstated state machine from snapshot");
     }
 
@@ -150,6 +152,7 @@
 
     /**
      * Applies an {@link AtomixLeaderElectorOperations.Run} commit.
+     *
      * @param commit commit entry
      * @return topic leader. If no previous leader existed this is the node that just entered the race.
      */
@@ -160,10 +163,11 @@
             Registration registration = new Registration(commit.value().nodeId(), commit.session().sessionId().id());
             elections.compute(topic, (k, v) -> {
                 if (v == null) {
-                    return new ElectionState(registration, termCounter(topic)::incrementAndGet);
+                    return new ElectionState(registration, termCounter(topic)::incrementAndGet, elections);
                 } else {
                     if (!v.isDuplicate(registration)) {
-                        return new ElectionState(v).addRegistration(registration, termCounter(topic)::incrementAndGet);
+                        return new ElectionState(v).addRegistration(
+                                topic, registration, termCounter(topic)::incrementAndGet);
                     } else {
                         return v;
                     }
@@ -183,14 +187,15 @@
 
     /**
      * Applies an {@link AtomixLeaderElectorOperations.Withdraw} commit.
+     *
      * @param commit withdraw commit
      */
     public void withdraw(Commit<? extends Withdraw> commit) {
         try {
             String topic = commit.value().topic();
             Leadership oldLeadership = leadership(topic);
-            elections.computeIfPresent(topic, (k, v) -> v.cleanup(commit.session(),
-                    termCounter(topic)::incrementAndGet));
+            elections.computeIfPresent(topic, (k, v) -> v.cleanup(
+                    topic, commit.session(), termCounter(topic)::incrementAndGet));
             Leadership newLeadership = leadership(topic);
             if (!Objects.equal(oldLeadership, newLeadership)) {
                 notifyLeadershipChange(oldLeadership, newLeadership);
@@ -203,6 +208,7 @@
 
     /**
      * Applies an {@link AtomixLeaderElectorOperations.Anoint} commit.
+     *
      * @param commit anoint commit
      * @return {@code true} if changes were made and the transfer occurred; {@code false} if it did not.
      */
@@ -228,6 +234,7 @@
 
     /**
      * Applies an {@link AtomixLeaderElectorOperations.Promote} commit.
+     *
      * @param commit promote commit
      * @return {@code true} if changes desired end state is achieved.
      */
@@ -253,6 +260,7 @@
 
     /**
      * Applies an {@link AtomixLeaderElectorOperations.Evict} commit.
+     *
      * @param commit evict commit
      */
     public void evict(Commit<? extends Evict> commit) {
@@ -277,6 +285,7 @@
 
     /**
      * Applies an {@link AtomixLeaderElectorOperations.GetLeadership} commit.
+     *
      * @param commit GetLeadership commit
      * @return leader
      */
@@ -292,6 +301,7 @@
 
     /**
      * Applies an {@link AtomixLeaderElectorOperations.GetElectedTopics} commit.
+     *
      * @param commit commit entry
      * @return set of topics for which the node is the leader
      */
@@ -310,6 +320,7 @@
 
     /**
      * Applies an {@link AtomixLeaderElectorOperations#GET_ALL_LEADERSHIPS} commit.
+     *
      * @param commit GetAllLeaderships commit
      * @return topic to leader mapping
      */
@@ -346,7 +357,7 @@
         List<Change<Leadership>> changes = Lists.newArrayList();
         topics.forEach(topic -> {
             Leadership oldLeadership = leadership(topic);
-            elections.compute(topic, (k, v) -> v.cleanup(session, termCounter(topic)::incrementAndGet));
+            elections.compute(topic, (k, v) -> v.cleanup(topic, session, termCounter(topic)::incrementAndGet));
             Leadership newLeadership = leadership(topic);
             if (!Objects.equal(oldLeadership, newLeadership)) {
                 changes.add(new Change<>(oldLeadership, newLeadership));
@@ -381,17 +392,20 @@
         }
     }
 
-    private static class ElectionState {
+    private class ElectionState {
         final Registration leader;
         final long term;
         final long termStartTime;
         final List<Registration> registrations;
+        transient Map<String, ElectionState> elections;
 
-        public ElectionState(Registration registration, Supplier<Long> termCounter) {
+        public ElectionState(Registration registration, Supplier<Long> termCounter,
+                             Map<String, ElectionState> elections) {
             registrations = Arrays.asList(registration);
             term = termCounter.get();
             termStartTime = System.currentTimeMillis();
             leader = registration;
+            this.elections = elections;
         }
 
         public ElectionState(ElectionState other) {
@@ -399,19 +413,22 @@
             leader = other.leader;
             term = other.term;
             termStartTime = other.termStartTime;
+            elections = other.elections;
         }
 
         public ElectionState(List<Registration> registrations,
                 Registration leader,
                 long term,
-                long termStartTime) {
+                long termStartTime,
+                Map<String, ElectionState> elections) {
             this.registrations = Lists.newArrayList(registrations);
             this.leader = leader;
             this.term = term;
             this.termStartTime = termStartTime;
+            this.elections = elections;
         }
 
-        public ElectionState cleanup(RaftSession session, Supplier<Long> termCounter) {
+        public ElectionState cleanup(String topic, RaftSession session, Supplier<Long> termCounter) {
             Optional<Registration> registration =
                     registrations.stream().filter(r -> r.sessionId() == session.sessionId().id()).findFirst();
             if (registration.isPresent()) {
@@ -421,15 +438,32 @@
                                 .collect(Collectors.toList());
                 if (leader.sessionId() == session.sessionId().id()) {
                     if (!updatedRegistrations.isEmpty()) {
+                        updatedRegistrations.sort((a, b) -> { // fewest other leaderships first, then session id
+                            long aCount = elections.entrySet().stream()
+                                    .filter(entry -> !entry.getKey().equals(topic) && entry.getValue().leader != null
+                                            && entry.getValue().leader.nodeId.id()
+                                            .equals(sessions().getSession(a.sessionId).memberId().id()))
+                                    .count(); // elections that currently have no leader are not counted
+                            long bCount = elections.entrySet().stream()
+                                    .filter(entry -> !entry.getKey().equals(topic) && entry.getValue().leader != null
+                                            && entry.getValue().leader.nodeId.id()
+                                            .equals(sessions().getSession(b.sessionId).memberId().id()))
+                                    .count();
+                            return ComparisonChain.start()
+                                    .compare(aCount, bCount)
+                                    .compare(a.sessionId, b.sessionId)
+                                    .result();
+                        });
                         return new ElectionState(updatedRegistrations,
                                 updatedRegistrations.get(0),
                                 termCounter.get(),
-                                System.currentTimeMillis());
+                                System.currentTimeMillis(),
+                                elections);
                     } else {
-                        return new ElectionState(updatedRegistrations, null, term, termStartTime);
+                        return new ElectionState(updatedRegistrations, null, term, termStartTime, elections);
                     }
                 } else {
-                    return new ElectionState(updatedRegistrations, leader, term, termStartTime);
+                    return new ElectionState(updatedRegistrations, leader, term, termStartTime, elections);
                 }
             } else {
                 return this;
@@ -449,12 +483,13 @@
                         return new ElectionState(updatedRegistrations,
                                 updatedRegistrations.get(0),
                                 termCounter.get(),
-                                System.currentTimeMillis());
+                                System.currentTimeMillis(),
+                                elections);
                     } else {
-                        return new ElectionState(updatedRegistrations, null, term, termStartTime);
+                        return new ElectionState(updatedRegistrations, null, term, termStartTime, elections);
                     }
                 } else {
-                    return new ElectionState(updatedRegistrations, leader, term, termStartTime);
+                    return new ElectionState(updatedRegistrations, leader, term, termStartTime, elections);
                 }
             } else {
                 return this;
@@ -478,15 +513,40 @@
             return registrations.stream().map(registration -> registration.nodeId()).collect(Collectors.toList());
         }
 
-        public ElectionState addRegistration(Registration registration, Supplier<Long> termCounter) {
+        public ElectionState addRegistration(String topic, Registration registration, Supplier<Long> termCounter) {
             if (!registrations.stream().anyMatch(r -> r.sessionId() == registration.sessionId())) {
                 List<Registration> updatedRegistrations = new LinkedList<>(registrations);
                 updatedRegistrations.add(registration);
-                boolean newLeader = leader == null;
+                updatedRegistrations.sort((a, b) -> { // fewest other-topic leaderships first, then session id
+                    long aCount = elections.entrySet().stream()
+                            .filter(entry -> !entry.getKey().equals(topic) && entry.getValue().leader != null
+                                    && entry.getValue().leader.nodeId.id()
+                                    .equals(sessions().getSession(a.sessionId).memberId().id()))
+                            .count(); // elections that currently have no leader are not counted
+                    long bCount = elections.entrySet().stream()
+                            .filter(entry -> !entry.getKey().equals(topic) && entry.getValue().leader != null
+                                    && entry.getValue().leader.nodeId.id()
+                                    .equals(sessions().getSession(b.sessionId).memberId().id()))
+                            .count();
+                    return ComparisonChain.start()
+                            .compare(aCount, bCount)
+                            .compare(a.sessionId, b.sessionId)
+                            .result();
+                });
+                Registration firstRegistration = updatedRegistrations.get(0);
+                Registration leader = this.leader;
+                long term = this.term;
+                long termStartTime = this.termStartTime;
+                if (leader == null || !leader.equals(firstRegistration)) { // head of the sorted list takes over
+                    leader = firstRegistration;
+                    term = termCounter.get();
+                    termStartTime = System.currentTimeMillis();
+                }
                 return new ElectionState(updatedRegistrations,
-                        newLeader ? registration : leader,
-                        newLeader ? termCounter.get() : term,
-                        newLeader ? System.currentTimeMillis() : termStartTime);
+                        leader,
+                        term,
+                        termStartTime,
+                        elections);
             }
             return this;
         }
@@ -500,7 +560,8 @@
                 return new ElectionState(registrations,
                         newLeader,
                         termCounter.incrementAndGet(),
-                        System.currentTimeMillis());
+                        System.currentTimeMillis(),
+                        elections);
             } else {
                 return this;
             }
@@ -519,7 +580,8 @@
             return new ElectionState(updatedRegistrations,
                     leader,
                     term,
-                    termStartTime);
+                    termStartTime,
+                    elections);
 
         }
     }
diff --git a/core/store/primitives/src/test/java/org/onosproject/store/primitives/resources/impl/AtomixLeaderElectorTest.java b/core/store/primitives/src/test/java/org/onosproject/store/primitives/resources/impl/AtomixLeaderElectorTest.java
index b0ec22d..3eda617 100644
--- a/core/store/primitives/src/test/java/org/onosproject/store/primitives/resources/impl/AtomixLeaderElectorTest.java
+++ b/core/store/primitives/src/test/java/org/onosproject/store/primitives/resources/impl/AtomixLeaderElectorTest.java
@@ -36,9 +36,9 @@
  */
 public class AtomixLeaderElectorTest extends AtomixTestBase<AtomixLeaderElector> {
 
-    NodeId node1 = new NodeId("node1");
-    NodeId node2 = new NodeId("node2");
-    NodeId node3 = new NodeId("node3");
+    NodeId node1 = new NodeId("4");
+    NodeId node2 = new NodeId("5");
+    NodeId node3 = new NodeId("6");
 
     @Override
     protected RaftService createService() {
@@ -64,13 +64,20 @@
             assertEquals(node1, result.candidates().get(0));
         }).join();
 
-        AtomixLeaderElector elector2 = newPrimitive("test-elector-run");
-        elector2.run("foo", node2).thenAccept(result -> {
+        elector1.run("bar", node1).thenAccept(result -> {
             assertEquals(node1, result.leaderNodeId());
             assertEquals(1, result.leader().term());
-            assertEquals(2, result.candidates().size());
+            assertEquals(1, result.candidates().size());
             assertEquals(node1, result.candidates().get(0));
-            assertEquals(node2, result.candidates().get(1));
+        }).join();
+
+        AtomixLeaderElector elector2 = newPrimitive("test-elector-run");
+        elector2.run("bar", node2).thenAccept(result -> {
+            assertEquals(node2, result.leaderNodeId());
+            assertEquals(2, result.leader().term());
+            assertEquals(2, result.candidates().size());
+            assertEquals(node2, result.candidates().get(0));
+            assertEquals(node1, result.candidates().get(1));
         }).join();
     }
 
@@ -261,6 +268,50 @@
     }
 
     @Test
+    public void testLeaderBalance() throws Throwable {
+        AtomixLeaderElector elector1 = newPrimitive("test-elector-leader-session-close");
+        elector1.run("foo", node1).join();
+        elector1.run("bar", node1).join();
+        elector1.run("baz", node1).join();
+
+        AtomixLeaderElector elector2 = newPrimitive("test-elector-leader-session-close");
+        elector2.run("foo", node2).join();
+        elector2.run("bar", node2).join();
+        elector2.run("baz", node2).join();
+
+        AtomixLeaderElector elector3 = newPrimitive("test-elector-leader-session-close");
+        elector3.run("foo", node3).join();
+        elector3.run("bar", node3).join();
+        elector3.run("baz", node3).join();
+
+        LeaderEventListener listener = new LeaderEventListener();
+        elector2.addChangeListener(listener).join();
+
+        elector1.proxy.close(); // drop node1's session; the three events below are expected as a result
+
+        listener.nextEvent().thenAccept(result -> {
+            assertEquals(node3, result.newValue().leaderNodeId());
+            assertEquals(2, result.newValue().candidates().size());
+            assertEquals(node3, result.newValue().candidates().get(0));
+            assertEquals(node2, result.newValue().candidates().get(1));
+        }).join();
+
+        listener.nextEvent().thenAccept(result -> {
+            assertEquals(node2, result.newValue().leaderNodeId());
+            assertEquals(2, result.newValue().candidates().size());
+            assertEquals(node2, result.newValue().candidates().get(0));
+            assertEquals(node3, result.newValue().candidates().get(1));
+        }).join(); // join so a failed assertion in the callback fails the test instead of being dropped
+
+        listener.nextEvent().thenAccept(result -> {
+            assertEquals(node2, result.newValue().leaderNodeId());
+            assertEquals(2, result.newValue().candidates().size());
+            assertEquals(node2, result.newValue().candidates().get(0));
+            assertEquals(node3, result.newValue().candidates().get(1));
+        }).join();
+    }
+
+    @Test
     public void testQueries() throws Throwable {
         leaderElectorQueryTests();
     }