diff --git a/src/main/java/org/onosproject/segmentrouting/McastHandler.java b/src/main/java/org/onosproject/segmentrouting/McastHandler.java
index 96a2337..f78722d 100644
--- a/src/main/java/org/onosproject/segmentrouting/McastHandler.java
+++ b/src/main/java/org/onosproject/segmentrouting/McastHandler.java
@@ -197,7 +197,6 @@
         ConnectPoint source = mcastRouteInfo.source().orElse(null);
         ConnectPoint sink = mcastRouteInfo.sink().orElse(null);
         IpAddress mcastIp = mcastRouteInfo.route().group();
-        VlanId assignedVlan = assignedVlan();
 
         // When source and sink are on the same device
         if (source.deviceId().equals(sink.deviceId())) {
@@ -206,12 +205,12 @@
                 log.warn("Sink is on the same port of source. Abort");
                 return;
             }
-            removePortFromDevice(sink.deviceId(), sink.port(), mcastIp, assignedVlan);
+            removePortFromDevice(sink.deviceId(), sink.port(), mcastIp, assignedVlan(source));
             return;
         }
 
         // Process the egress device
-        boolean isLast = removePortFromDevice(sink.deviceId(), sink.port(), mcastIp, assignedVlan);
+        boolean isLast = removePortFromDevice(sink.deviceId(), sink.port(), mcastIp, assignedVlan(null));
         if (isLast) {
             mcastRoleStore.remove(new McastStoreKey(mcastIp, sink.deviceId()));
         }
@@ -224,7 +223,8 @@
             for (Link link : links) {
                 if (isLast) {
                     isLast = removePortFromDevice(link.src().deviceId(), link.src().port(),
-                            mcastIp, assignedVlan);
+                            mcastIp,
+                            assignedVlan(link.src().deviceId().equals(source.deviceId()) ? source : null));
                     mcastRoleStore.remove(new McastStoreKey(mcastIp, link.src().deviceId()));
                 }
             }
@@ -240,10 +240,8 @@
      */
     private void processSinkAddedInternal(ConnectPoint source, ConnectPoint sink,
             IpAddress mcastIp) {
-        VlanId assignedVlan = assignedVlan();
-
         // Process the ingress device
-        addFilterToDevice(source.deviceId(), source.port(), assignedVlan);
+        addFilterToDevice(source.deviceId(), source.port(), assignedVlan(source));
 
         // When source and sink are on the same device
         if (source.deviceId().equals(sink.deviceId())) {
@@ -252,7 +250,7 @@
                 log.warn("Sink is on the same port of source. Abort");
                 return;
             }
-            addPortToDevice(sink.deviceId(), sink.port(), mcastIp, assignedVlan);
+            addPortToDevice(sink.deviceId(), sink.port(), mcastIp, assignedVlan(source));
             mcastRoleStore.put(new McastStoreKey(mcastIp, sink.deviceId()), McastRole.INGRESS);
             return;
         }
@@ -265,12 +263,13 @@
                     "Path in leaf-spine topology should always be two hops: ", links);
 
             links.forEach(link -> {
-                addFilterToDevice(link.dst().deviceId(), link.dst().port(), assignedVlan);
-                addPortToDevice(link.src().deviceId(), link.src().port(), mcastIp, assignedVlan);
+                addPortToDevice(link.src().deviceId(), link.src().port(), mcastIp,
+                        assignedVlan(link.src().deviceId().equals(source.deviceId()) ? source : null));
+                addFilterToDevice(link.dst().deviceId(), link.dst().port(), assignedVlan(null));
             });
 
             // Process the egress device
-            addPortToDevice(sink.deviceId(), sink.port(), mcastIp, assignedVlan);
+            addPortToDevice(sink.deviceId(), sink.port(), mcastIp, assignedVlan(null));
 
             // Setup mcast roles
             mcastRoleStore.put(new McastStoreKey(mcastIp, source.deviceId()),
@@ -291,8 +290,6 @@
      * @param affectedLink Link that is going down
      */
     protected void processLinkDown(Link affectedLink) {
-        VlanId assignedVlan = assignedVlan();
-
         getAffectedGroups(affectedLink).forEach(mcastIp -> {
             // Find out the ingress, transit and egress device of affected group
             DeviceId ingressDevice = getDevice(mcastIp, McastRole.INGRESS)
@@ -300,19 +297,23 @@
             DeviceId transitDevice = getDevice(mcastIp, McastRole.TRANSIT)
                     .stream().findAny().orElse(null);
             Set<DeviceId> egressDevices = getDevice(mcastIp, McastRole.EGRESS);
-            if (ingressDevice == null || transitDevice == null || egressDevices == null) {
-                log.warn("Missing ingress {}, transit {}, or egress {} devices",
-                        ingressDevice, transitDevice, egressDevices);
+            ConnectPoint source = getSource(mcastIp);
+
+            // Do not proceed if any of these info is missing
+            if (ingressDevice == null || transitDevice == null
+                    || egressDevices == null || source == null) {
+                log.warn("Missing ingress {}, transit {}, egress {} devices or source {}",
+                        ingressDevice, transitDevice, egressDevices, source);
                 return;
             }
 
             // Remove entire transit
-            removeGroupFromDevice(transitDevice, mcastIp, assignedVlan);
+            removeGroupFromDevice(transitDevice, mcastIp, assignedVlan(null));
 
             // Remove transit-facing port on ingress device
             PortNumber ingressTransitPort = ingressTransitPort(mcastIp);
             if (ingressTransitPort != null) {
-                removePortFromDevice(ingressDevice, ingressTransitPort, mcastIp, assignedVlan);
+                removePortFromDevice(ingressDevice, ingressTransitPort, mcastIp, assignedVlan(source));
                 mcastRoleStore.remove(new McastStoreKey(mcastIp, transitDevice));
             }
 
@@ -322,8 +323,9 @@
                 if (mcastPath.isPresent()) {
                     List<Link> links = mcastPath.get().links();
                     links.forEach(link -> {
-                        addPortToDevice(link.src().deviceId(), link.src().port(), mcastIp, assignedVlan);
-                        addFilterToDevice(link.dst().deviceId(), link.dst().port(), assignedVlan);
+                        addPortToDevice(link.src().deviceId(), link.src().port(), mcastIp,
+                                assignedVlan(link.src().deviceId().equals(source.deviceId()) ? source : null));
+                        addFilterToDevice(link.dst().deviceId(), link.dst().port(), assignedVlan(null));
                     });
                     // Setup new transit mcast role
                     mcastRoleStore.put(new McastStoreKey(mcastIp,
@@ -331,7 +333,7 @@
                 } else {
                     log.warn("Fail to recover egress device {} from link failure {}",
                             egressDevice, affectedLink);
-                    removeGroupFromDevice(egressDevice, mcastIp, assignedVlan);
+                    removeGroupFromDevice(egressDevice, mcastIp, assignedVlan(null));
                 }
             });
         });
@@ -521,7 +523,9 @@
         while (itNextObj.hasNext()) {
             Map.Entry<McastStoreKey, Versioned<NextObjective>> entry = itNextObj.next();
             if (entry.getKey().deviceId().equals(deviceId)) {
-                removeGroupFromDevice(entry.getKey().deviceId(), entry.getKey().mcastIp(), assignedVlan());
+                ConnectPoint source = getSource(entry.getKey().mcastIp());
+                removeGroupFromDevice(entry.getKey().deviceId(), entry.getKey().mcastIp(),
+                        assignedVlan(deviceId.equals(source.deviceId()) ? source : null));
                 itNextObj.remove();
             }
         }
@@ -693,6 +697,19 @@
     }
 
     /**
+     * Gets source connect point of given multicast group.
+     *
+     * @param mcastIp multicast IP
+     * @return source connect point or null if not found
+     */
+    private ConnectPoint getSource(IpAddress mcastIp) {
+        return srManager.multicastRouteService.getRoutes().stream()
+                .filter(mcastRoute -> mcastRoute.group().equals(mcastIp))
+                .map(mcastRoute -> srManager.multicastRouteService.fetchSource(mcastRoute))
+                .findAny().orElse(null);
+    }
+
+    /**
      * Gets groups which is affected by the link down event.
      *
      * @param link link going down
@@ -721,13 +738,27 @@
 
     /**
      * Gets assigned VLAN according to the value of egress VLAN.
+     * If connect point is specified, try to reuse the assigned VLAN on the connect point.
      *
-     * @return assigned VLAN
+     * @param cp connect point; Can be null if not specified
+     * @return assigned VLAN ID
      */
-    private VlanId assignedVlan() {
-        return (egressVlan().equals(VlanId.NONE)) ?
-                VlanId.vlanId(SegmentRoutingManager.ASSIGNED_VLAN_NO_SUBNET) :
-                egressVlan();
+    private VlanId assignedVlan(ConnectPoint cp) {
+        // Use the egressVlan if it is tagged
+        if (!egressVlan().equals(VlanId.NONE)) {
+            return egressVlan();
+        }
+        // Reuse unicast VLAN if the port has subnet configured
+        if (cp != null) {
+            Ip4Prefix portSubnet = srManager.deviceConfiguration
+                    .getPortSubnet(cp.deviceId(), cp.port());
+            VlanId unicastVlan = srManager.getSubnetAssignedVlanId(cp.deviceId(), portSubnet);
+            if (unicastVlan != null) {
+                return unicastVlan;
+            }
+        }
+        // By default, use VLAN_NO_SUBNET
+        return VlanId.vlanId(SegmentRoutingManager.ASSIGNED_VLAN_NO_SUBNET);
     }
 
     /**
