Always to multios, correctly support duplication of file descriptors’ channel

git-svn-id: https://svn.apache.org/repos/asf/felix/trunk@1736049 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/gogo/runtime/src/main/java/org/apache/felix/gogo/runtime/Pipe.java b/gogo/runtime/src/main/java/org/apache/felix/gogo/runtime/Pipe.java
index 492a198..0073894 100644
--- a/gogo/runtime/src/main/java/org/apache/felix/gogo/runtime/Pipe.java
+++ b/gogo/runtime/src/main/java/org/apache/felix/gogo/runtime/Pipe.java
@@ -26,6 +26,7 @@
 import java.nio.channels.ByteChannel;
 import java.nio.channels.Channel;
 import java.nio.channels.Channels;
+import java.nio.channels.ClosedChannelException;
 import java.nio.channels.ReadableByteChannel;
 import java.nio.channels.WritableByteChannel;
 import java.nio.file.Files;
@@ -37,6 +38,7 @@
 import java.util.Set;
 import java.util.concurrent.Callable;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
@@ -140,41 +142,30 @@
         if (fd == 2 && (readWrite & WRITE) == 0) {
             throw new IllegalArgumentException("Stderr is not writable");
         }
-        // TODO: externalize
-        boolean multios = true;
-        if (multios) {
-            if (streams[fd] != null && (readWrite & READ) != 0 && (readWrite & WRITE) != 0) {
-                throw new IllegalArgumentException("Can not do multios with read/write streams");
-            }
-            // If channel is inherited (for example standard input / output), replace it
-            if (streams[fd] != null && !toclose[fd]) {
-                streams[fd] = ch;
-                toclose[fd] = true;
-            }
-            // Else do multios
-            else {
-                MultiChannel mrbc;
-                // If the channel is already multios
-                if (streams[fd] instanceof MultiChannel) {
-                    mrbc = (MultiChannel) streams[fd];
-                }
-                // Else create a multios channel
-                else {
-                    mrbc = new MultiChannel();
-                    mrbc.addChannel(streams[fd], toclose[fd]);
-                    streams[fd] = mrbc;
-                    toclose[fd] = true;
-                }
-                mrbc.addChannel(ch, true);
-            }
+        if (streams[fd] != null && (readWrite & READ) != 0 && (readWrite & WRITE) != 0) {
+            throw new IllegalArgumentException("Can not do multios with read/write streams");
         }
-        else {
-            if (streams[fd] != null && toclose[fd]) {
-                streams[fd].close();
-            }
+        // If channel is inherited (for example standard input / output), replace it
+        if (streams[fd] != null && !toclose[fd]) {
             streams[fd] = ch;
             toclose[fd] = true;
         }
+        // Else do multios
+        else {
+            MultiChannel mrbc;
+            // If the channel is already multios
+            if (streams[fd] instanceof MultiChannel) {
+                mrbc = (MultiChannel) streams[fd];
+            }
+            // Else create a multios channel
+            else {
+                mrbc = new MultiChannel();
+                mrbc.addChannel(streams[fd], toclose[fd]);
+                streams[fd] = mrbc;
+                toclose[fd] = true;
+            }
+            mrbc.addChannel(ch, true);
+        }
     }
 
     @Override
@@ -248,10 +239,18 @@
                     if (streams[fd0] != null && toclose[fd0]) {
                         streams[fd0].close();
                     }
-                    streams[fd0] = streams[fd1];
-                    // TODO: this is wrong, we should keep a counter somehow so that the
-                    // stream is closed when both are closed
-                    toclose[fd0] = false;
+                    // If the stream has to be closed, close it when both streams are closed
+                    if (toclose[fd1]) {
+                        Channel channel = streams[fd1];
+                        AtomicInteger references = new AtomicInteger();
+                        streams[fd0] = new RefByteChannel(channel, references);
+                        streams[fd1] = new RefByteChannel(channel, references);
+                        toclose[fd0] = true;
+                    }
+                    else {
+                        streams[fd0] = streams[fd1];
+                        toclose[fd0] = false;
+                    }
                 }
                 else if ((m = Pattern.compile("([0-9])?<(>)?").matcher(t)).matches()) {
                     int fd = 0;
@@ -317,9 +316,6 @@
         }
         catch (Exception e)
         {
-            // TODO: use shell name instead of 'gogo'
-            // TODO: use color if not redirected
-            // TODO: use conversion ?
             String msg = "gogo: " + e.getClass().getSimpleName() + ": " + e.getMessage() + "\n";
             try {
                 errChannel.write(ByteBuffer.wrap(msg.getBytes()));
@@ -365,6 +361,50 @@
         return mch;
     }
 
+    private class RefByteChannel implements ByteChannel {
+
+        private final Channel channel;
+        private final AtomicInteger references;
+        private final AtomicBoolean closed = new AtomicBoolean(false);
+
+        public RefByteChannel(Channel channel, AtomicInteger references) {
+            this.channel = channel;
+            this.references = references;
+            references.incrementAndGet();
+        }
+
+        @Override
+        public int read(ByteBuffer dst) throws IOException {
+            ensureOpen();
+            return ((ReadableByteChannel) channel).read(dst);
+        }
+
+        @Override
+        public int write(ByteBuffer src) throws IOException {
+            ensureOpen();
+            return ((WritableByteChannel) channel).write(src);
+        }
+
+        @Override
+        public boolean isOpen() {
+            return !closed.get();
+        }
+
+        private void ensureOpen() throws ClosedChannelException {
+            if (closed.get()) throw new ClosedChannelException();
+        }
+
+        @Override
+        public void close() throws IOException {
+            if (closed.compareAndSet(false, true)) {
+                if (references.decrementAndGet() == 0) {
+                    channel.close();
+                }
+            }
+        }
+
+    }
+
     private class MultiChannel implements ByteChannel {
         protected final List<Channel> channels = new ArrayList<>();
         protected final List<Channel> toClose = new ArrayList<>();