blob: c9744863c2912e0c59ac97b7fa9b1c0cde171099 [file] [log] [blame]
Madan Jampani08822c42014-11-04 17:17:46 -08001package org.onlab.onos.store.service.impl;
2
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import net.kuujo.copycat.Command;
import net.kuujo.copycat.Query;
import net.kuujo.copycat.StateMachine;

import org.onlab.onos.store.serializers.KryoSerializer;
import org.onlab.onos.store.service.ReadRequest;
import org.onlab.onos.store.service.ReadResult;
import org.onlab.onos.store.service.VersionedValue;
import org.onlab.onos.store.service.WriteRequest;
import org.onlab.onos.store.service.WriteResult;
import org.onlab.util.KryoNamespace;

import com.google.common.collect.Maps;
21
22public class DatabaseStateMachine implements StateMachine {
23
24 public static final KryoSerializer SERIALIZER = new KryoSerializer() {
25 @Override
26 protected void setupKryoPool() {
27 serializerPool = KryoNamespace.newBuilder()
28 .register(VersionedValue.class)
29 .register(State.class)
Madan Jampani9b19a822014-11-04 21:37:13 -080030 .register(ClusterMessagingProtocol.COMMON)
Madan Jampani08822c42014-11-04 17:17:46 -080031 .build()
32 .populate(1);
33 }
34 };
35
36 private State state = new State();
37
38 @Command
39 public boolean createTable(String tableName) {
40 return state.getTables().putIfAbsent(tableName, Maps.newHashMap()) == null;
41 }
42
43 @Command
44 public boolean dropTable(String tableName) {
45 return state.getTables().remove(tableName) != null;
46 }
47
48 @Command
49 public boolean dropAllTables() {
50 state.getTables().clear();
51 return true;
52 }
53
54 @Query
55 public Set<String> listTables() {
56 return state.getTables().keySet();
57 }
58
59 @Query
60 public List<InternalReadResult> read(List<ReadRequest> requests) {
61 List<InternalReadResult> results = new ArrayList<>(requests.size());
62 for (ReadRequest request : requests) {
63 Map<String, VersionedValue> table = state.getTables().get(request.tableName());
64 if (table == null) {
65 results.add(new InternalReadResult(InternalReadResult.Status.NO_SUCH_TABLE, null));
66 continue;
67 }
68 VersionedValue value = table.get(request.key());
69 results.add(new InternalReadResult(
70 InternalReadResult.Status.OK,
71 new ReadResult(
72 request.tableName(),
73 request.key(),
74 value)));
75 }
76 return results;
77 }
78
79 @Command
80 public List<InternalWriteResult> write(List<WriteRequest> requests) {
81 boolean abort = false;
82 List<InternalWriteResult.Status> validationResults = new ArrayList<>(requests.size());
83 for (WriteRequest request : requests) {
84 Map<String, VersionedValue> table = state.getTables().get(request.tableName());
85 if (table == null) {
86 validationResults.add(InternalWriteResult.Status.NO_SUCH_TABLE);
87 abort = true;
88 continue;
89 }
90 VersionedValue value = table.get(request.key());
91 if (value == null) {
92 if (request.oldValue() != null) {
93 validationResults.add(InternalWriteResult.Status.PREVIOUS_VALUE_MISMATCH);
94 abort = true;
95 continue;
96 } else if (request.previousVersion() >= 0) {
97 validationResults.add(InternalWriteResult.Status.OPTIMISTIC_LOCK_FAILURE);
98 abort = true;
99 continue;
100 }
101 }
102 if (request.previousVersion() >= 0 && value.version() != request.previousVersion()) {
103 validationResults.add(InternalWriteResult.Status.OPTIMISTIC_LOCK_FAILURE);
104 abort = true;
105 continue;
106 }
107
108 validationResults.add(InternalWriteResult.Status.OK);
109 }
110
111 List<InternalWriteResult> results = new ArrayList<>(requests.size());
112
113 if (abort) {
114 for (InternalWriteResult.Status validationResult : validationResults) {
115 if (validationResult == InternalWriteResult.Status.OK) {
116 results.add(new InternalWriteResult(InternalWriteResult.Status.ABORTED, null));
117 } else {
118 results.add(new InternalWriteResult(validationResult, null));
119 }
120 }
121 return results;
122 }
123
124 for (WriteRequest request : requests) {
125 Map<String, VersionedValue> table = state.getTables().get(request.tableName());
126 synchronized (table) {
127 VersionedValue previousValue =
128 table.put(request.key(), new VersionedValue(request.newValue(), state.nextVersion()));
129 results.add(new InternalWriteResult(
130 InternalWriteResult.Status.OK,
131 new WriteResult(request.tableName(), request.key(), previousValue)));
132 }
133 }
134 return results;
135 }
136
137 public class State {
138
139 private final Map<String, Map<String, VersionedValue>> tables =
140 Maps.newHashMap();
141 private long versionCounter = 1;
142
143 Map<String, Map<String, VersionedValue>> getTables() {
144 return tables;
145 }
146
147 long nextVersion() {
148 return versionCounter++;
149 }
150 }
151
152 @Override
153 public byte[] takeSnapshot() {
154 try {
155 return SERIALIZER.encode(state);
156 } catch (Exception e) {
157 e.printStackTrace();
158 return null;
159 }
160 }
161
162 @Override
163 public void installSnapshot(byte[] data) {
164 try {
165 this.state = SERIALIZER.decode(data);
166 } catch (Exception e) {
167 e.printStackTrace();
168 }
169 }
170}