Fix for ONOS-5035

Change-Id: I0a8edb0c77d2803070dba10c06f83390e1a09832
diff --git a/core/store/dist/src/main/java/org/onosproject/store/host/impl/DistributedHostStore.java b/core/store/dist/src/main/java/org/onosproject/store/host/impl/DistributedHostStore.java
index 356df9c..72f2407 100644
--- a/core/store/dist/src/main/java/org/onosproject/store/host/impl/DistributedHostStore.java
+++ b/core/store/dist/src/main/java/org/onosproject/store/host/impl/DistributedHostStore.java
@@ -15,8 +15,17 @@
  */
 package org.onosproject.store.host.impl;
 
-import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Sets;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
 
 import org.apache.felix.scr.annotations.Activate;
 import org.apache.felix.scr.annotations.Component;
@@ -44,24 +53,25 @@
 import org.onosproject.store.AbstractStore;
 import org.onosproject.store.serializers.KryoNamespaces;
 import org.onosproject.store.service.ConsistentMap;
+import org.onosproject.store.service.DistributedPrimitive.Status;
 import org.onosproject.store.service.MapEvent;
 import org.onosproject.store.service.MapEventListener;
 import org.onosproject.store.service.Serializer;
 import org.onosproject.store.service.StorageService;
 import org.slf4j.Logger;
 
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
-import java.util.function.Predicate;
-import java.util.stream.Collectors;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
 
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
+import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor;
+import static org.onlab.util.Tools.groupedThreads;
 import static org.onosproject.net.DefaultAnnotations.merge;
-import static org.onosproject.net.host.HostEvent.Type.*;
+import static org.onosproject.net.host.HostEvent.Type.HOST_ADDED;
+import static org.onosproject.net.host.HostEvent.Type.HOST_MOVED;
+import static org.onosproject.net.host.HostEvent.Type.HOST_REMOVED;
+import static org.onosproject.net.host.HostEvent.Type.HOST_UPDATED;
 import static org.slf4j.LoggerFactory.getLogger;
 
 /**
@@ -80,10 +90,15 @@
 
     private ConsistentMap<HostId, DefaultHost> hostsConsistentMap;
     private Map<HostId, DefaultHost> hosts;
+    private Map<IpAddress, Set<Host>> hostsByIp;
 
     private MapEventListener<HostId, DefaultHost> hostLocationTracker =
             new HostLocationTracker();
 
+    private ScheduledExecutorService executor;
+
+    private Consumer<Status> statusChangeListener;
+
     @Activate
     public void activate() {
         KryoNamespace.Builder hostSerializer = KryoNamespace.newBuilder()
@@ -100,6 +115,14 @@
 
         hostsConsistentMap.addListener(hostLocationTracker);
 
+        executor = newSingleThreadScheduledExecutor(groupedThreads("onos/hosts", "store", log));
+        statusChangeListener = status -> {
+            if (status == Status.ACTIVE) {
+                executor.execute(this::loadHostsByIp);
+            }
+        };
+        hostsConsistentMap.addStatusChangeListener(statusChangeListener);
+        loadHostsByIp();
         log.info("Started");
     }
 
@@ -110,6 +133,20 @@
         log.info("Stopped");
     }
 
+     private void loadHostsByIp() {
+        hostsByIp = new ConcurrentHashMap<IpAddress, Set<Host>>();
+         hostsConsistentMap.asJavaMap().values().forEach(host -> {
+            host.ipAddresses().forEach(ip -> {
+                Set<Host> existingHosts = hostsByIp.get(ip);
+                if (existingHosts == null) {
+                    hostsByIp.put(ip, addHosts(host));
+                } else {
+                    existingHosts.add(host);
+                }
+            });
+        });
+    }
+
     private boolean shouldUpdate(DefaultHost existingHost,
                                  ProviderId providerId,
                                  HostId hostId,
@@ -210,6 +247,7 @@
                 if (addresses != null && addresses.contains(ipAddress)) {
                     addresses = new HashSet<>(existingHost.ipAddresses());
                     addresses.remove(ipAddress);
+                    removeIpFromHostsByIp(existingHost, ipAddress);
                     return new DefaultHost(existingHost.providerId(),
                             hostId,
                             existingHost.mac(),
@@ -253,7 +291,8 @@
 
     @Override
     public Set<Host> getHosts(IpAddress ip) {
-        return filter(hosts.values(), host -> host.ipAddresses().contains(ip));
+        Set<Host> hosts = hostsByIp.get(ip);
+        return hosts != null ? ImmutableSet.copyOf(hosts) : ImmutableSet.of();
     }
 
     @Override
@@ -278,18 +317,70 @@
         return collection.stream().filter(predicate).collect(Collectors.toSet());
     }
 
+    private Set<Host> addHosts(Host host) {
+        Set<Host> hosts = Sets.newConcurrentHashSet();
+        hosts.add(host);
+        return hosts;
+    }
+
+    private Set<Host> updateHosts(Set<Host> existingHosts, Host host) {
+        Iterator<Host> iterator = existingHosts.iterator();
+        while (iterator.hasNext()) {
+            Host existingHost = iterator.next();
+            if (existingHost.id().equals(host.id())) {
+                iterator.remove();
+            }
+        }
+        existingHosts.add(host);
+        return existingHosts;
+    }
+
+    private Set<Host> removeHosts(Set<Host> existingHosts, Host host) {
+        if (existingHosts != null) {
+            Iterator<Host> iterator = existingHosts.iterator();
+            while (iterator.hasNext()) {
+                Host existingHost = iterator.next();
+                if (existingHost.id().equals(host.id())) {
+                    iterator.remove();
+                }
+            }
+        }
+
+        if (existingHosts.isEmpty()) {
+            return null;
+        }
+        return existingHosts;
+    }
+
+    private void updateHostsByIp(DefaultHost host) {
+        host.ipAddresses().forEach(ip -> {
+            hostsByIp.compute(ip, (k, v) -> v == null ? addHosts(host)
+                                                      : updateHosts(v, host));
+        });
+    }
+
+    private void removeHostsByIp(DefaultHost host) {
+        host.ipAddresses().forEach(ip -> {
+            hostsByIp.computeIfPresent(ip, (k, v) -> removeHosts(v, host));
+        });
+    }
+
+    private void removeIpFromHostsByIp(DefaultHost host, IpAddress ip) {
+        hostsByIp.computeIfPresent(ip, (k, v) -> removeHosts(v, host));
+    }
+
     private class HostLocationTracker implements MapEventListener<HostId, DefaultHost> {
         @Override
         public void event(MapEvent<HostId, DefaultHost> event) {
-            Host host;
+            DefaultHost host = checkNotNull(event.value().value());
             switch (event.type()) {
                 case INSERT:
-                    host = checkNotNull(event.newValue().value());
+                    updateHostsByIp(host);
                     notifyDelegate(new HostEvent(HOST_ADDED, host));
                     break;
                 case UPDATE:
-                    host = checkNotNull(event.newValue().value());
-                    Host prevHost = checkNotNull(event.oldValue().value());
+                    updateHostsByIp(host);
+                    DefaultHost prevHost = checkNotNull(event.oldValue().value());
                     if (!Objects.equals(prevHost.location(), host.location())) {
                         notifyDelegate(new HostEvent(HOST_MOVED, host, prevHost));
                     } else if (!Objects.equals(prevHost, host)) {
@@ -297,7 +388,7 @@
                     }
                     break;
                 case REMOVE:
-                    host = checkNotNull(event.oldValue().value());
+                    updateHostsByIp(host);
                     notifyDelegate(new HostEvent(HOST_REMOVED, host));
                     break;
                 default:
diff --git a/core/store/dist/src/test/java/org/onosproject/store/host/impl/DistributedHostStoreTest.java b/core/store/dist/src/test/java/org/onosproject/store/host/impl/DistributedHostStoreTest.java
index 9f7338d..4c2e1db 100644
--- a/core/store/dist/src/test/java/org/onosproject/store/host/impl/DistributedHostStoreTest.java
+++ b/core/store/dist/src/test/java/org/onosproject/store/host/impl/DistributedHostStoreTest.java
@@ -15,6 +15,9 @@
  */
 package org.onosproject.store.host.impl;
 
+import java.util.HashSet;
+import java.util.Set;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -28,8 +31,7 @@
 import org.onosproject.net.provider.ProviderId;
 import org.onosproject.store.service.TestStorageService;
 
-import java.util.HashSet;
-import java.util.Set;
+import com.google.common.collect.Sets;
 
 import static junit.framework.TestCase.assertTrue;
 import static org.junit.Assert.assertFalse;
@@ -42,6 +44,7 @@
     private DistributedHostStore ecXHostStore;
 
     private static final HostId HOSTID = HostId.hostId(MacAddress.valueOf("1a:1a:1a:1a:1a:1a"));
+    private static final HostId HOSTID1 = HostId.hostId(MacAddress.valueOf("1a:1a:1a:1a:1a:1b"));
 
     private static final IpAddress IP1 = IpAddress.valueOf("10.2.0.2");
     private static final IpAddress IP2 = IpAddress.valueOf("10.2.0.3");
@@ -70,10 +73,7 @@
         ips.add(IP1);
         ips.add(IP2);
 
-        HostDescription description = new DefaultHostDescription(HOSTID.mac(),
-                                                                    HOSTID.vlanId(),
-                                                                    HostLocation.NONE,
-                                                                    ips);
+        HostDescription description = createHostDesc(HOSTID, ips);
         ecXHostStore.createOrUpdateHost(PID, HOSTID, description, false);
         ecXHostStore.removeIp(HOSTID, IP1);
         Host host = ecXHostStore.getHost(HOSTID);
@@ -82,4 +82,47 @@
         assertTrue(host.ipAddresses().contains(IP2));
     }
 
+    @Test
+    public void testAddHostByIp() {
+        Set<IpAddress> ips = new HashSet<>();
+        ips.add(IP1);
+        ips.add(IP2);
+
+        HostDescription description = createHostDesc(HOSTID, ips);
+        ecXHostStore.createOrUpdateHost(PID, HOSTID, description, false);
+
+        Set<Host> hosts = ecXHostStore.getHosts(IP1);
+
+        assertFalse(hosts.size() > 1);
+        assertTrue(hosts.size() == 1);
+
+        HostDescription description1 = createHostDesc(HOSTID1, Sets.newHashSet(IP2));
+        ecXHostStore.createOrUpdateHost(PID, HOSTID1, description1, false);
+
+        Set<Host> hosts1 = ecXHostStore.getHosts(IP2);
+
+        assertFalse(hosts1.size() < 1);
+        assertTrue(hosts1.size() == 2);
+    }
+
+    @Test
+    public void testRemoveHostByIp() {
+        Set<IpAddress> ips = new HashSet<>();
+        ips.add(IP1);
+        ips.add(IP2);
+
+        HostDescription description = createHostDesc(HOSTID, ips);
+        ecXHostStore.createOrUpdateHost(PID, HOSTID, description, false);
+        ecXHostStore.removeIp(HOSTID, IP1);
+        Set<Host> hosts = ecXHostStore.getHosts(IP1);
+        assertTrue(hosts.size() == 0);
+    }
+
+    private HostDescription createHostDesc(HostId hostId, Set<IpAddress> ips) {
+        return new DefaultHostDescription(hostId.mac(),
+                                          hostId.vlanId(),
+                                          HostLocation.NONE,
+                                          ips);
+    }
+
 }