DATAGEODE-364 - Refactor SimpleGemfireRepository.count() to be resilient when no results exist.

This commit is contained in:
John Blum
2020-08-28 15:44:11 -07:00
parent 8d9a126fbe
commit 3f69965372
2 changed files with 179 additions and 119 deletions

View File

@@ -18,6 +18,7 @@ package org.springframework.data.gemfire.repository.support;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -60,24 +61,25 @@ public class SimpleGemfireRepository<T, ID> implements GemfireRepository<T, ID>
private final GemfireTemplate template;
/**
* Creates a new {@link SimpleGemfireRepository}.
* Constructs a new instance of {@link SimpleGemfireRepository} initialized with the {@link GemfireTemplate}
* and {@link EntityInformation}.
*
* @param template must not be {@literal null}.
* @param entityInformation must not be {@literal null}.
* @param template {@link GemfireTemplate} used to perform basic data access operations and simple OQL queries;
* must not be {@literal null}.
* @param entityInformation {@link EntityInformation} used to describe the entity; must not be {@literal null}.
* @throws IllegalArgumentException if {@link GemfireTemplate} or {@link EntityInformation} is {@literal null}.
* @see org.springframework.data.gemfire.GemfireTemplate
* @see org.springframework.data.repository.core.EntityInformation
*/
public SimpleGemfireRepository(GemfireTemplate template, EntityInformation<T, ID> entityInformation) {
Assert.notNull(template, "Template must not be null");
Assert.notNull(template, "GemfireTemplate must not be null");
Assert.notNull(entityInformation, "EntityInformation must not be null");
this.template = template;
this.entityInformation = entityInformation;
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#save(S)
*/
@Override
public <U extends T> U save(U entity) {
@@ -88,10 +90,16 @@ public class SimpleGemfireRepository<T, ID> implements GemfireRepository<T, ID>
return entity;
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#save(java.lang.Iterable)
*/
@Override
public T save(Wrapper<T, ID> wrapper) {
T entity = wrapper.getEntity();
this.template.put(wrapper.getKey(), entity);
return entity;
}
@Override
public <U extends T> Iterable<U> saveAll(Iterable<U> entities) {
@@ -104,55 +112,31 @@ public class SimpleGemfireRepository<T, ID> implements GemfireRepository<T, ID>
return entitiesToSave.values();
}
/*
* (non-Javadoc)
* @see org.springframework.data.gemfire.repository.GemfireRepository#save(org.springframework.data.gemfire.repository.Wrapper)
*/
@Override
public T save(Wrapper<T, ID> wrapper) {
T entity = wrapper.getEntity();
this.template.put(wrapper.getKey(), entity);
return entity;
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#count()
*/
@Override
public long count() {
SelectResults<Integer> results =
this.template.find(String.format("SELECT count(*) FROM %s", this.template.getRegion().getFullPath()));
String countQuery = String.format("SELECT count(*) FROM %s", this.template.getRegion().getFullPath());
return Long.valueOf(results.iterator().next());
SelectResults<Integer> results = this.template.find(countQuery);
return Optional.ofNullable(results)
.map(SelectResults::iterator)
.filter(Iterator::hasNext)
.map(Iterator::next)
.map(Long::valueOf)
.orElse(0L);
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#existsById(java.lang.Object)
*/
@Override
public boolean existsById(ID id) {
return findById(id).isPresent();
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#findById(java.lang.Object)
*/
@Override
public Optional<T> findById(ID id) {
return Optional.ofNullable(this.template.get(id));
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#findAll()
*/
@Override
public Collection<T> findAll() {
@@ -162,10 +146,6 @@ public class SimpleGemfireRepository<T, ID> implements GemfireRepository<T, ID>
return results.asList();
}
/*
* (non-Javadoc)
* @see org.springframework.data.gemfire.repository.GemfireRepository.sort(:org.springframework.data.domain.Sort)
*/
@Override
public Iterable<T> findAll(Sort sort) {
@@ -178,10 +158,6 @@ public class SimpleGemfireRepository<T, ID> implements GemfireRepository<T, ID>
return selectResults.asList();
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#findAllById(java.lang.Iterable)
*/
@Override
public Collection<T> findAllById(Iterable<ID> ids) {
@@ -191,80 +167,45 @@ public class SimpleGemfireRepository<T, ID> implements GemfireRepository<T, ID>
.filter(Objects::nonNull).collect(Collectors.toList());
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#deleteById(java.lang.Object)
*/
@Override
public void deleteById(ID id) {
this.template.remove(id);
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#delete(java.lang.Object)
*/
@Override
public void delete(T entity) {
deleteById(this.entityInformation.getRequiredId(entity));
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#delete(java.lang.Iterable)
*/
@Override
public void deleteAll(Iterable<? extends T> entities) {
entities.forEach(this::delete);
}
/*
* (non-Javadoc)
* @see org.apache.geode.cache.Region#getAttributes()
* @see org.apache.geode.cache.RegionAttributes#getDataPolicy()
*/
boolean isPartitioned(Region<?, ?> region) {
return region != null && region.getAttributes() != null
&& isPartitioned(region.getAttributes().getDataPolicy());
}
/*
* (non-Javadoc)
* @see org.apache.geode.cache.DataPolicy#withPartitioning()
*/
boolean isPartitioned(DataPolicy dataPolicy) {
return dataPolicy != null && dataPolicy.withPartitioning();
}
/*
* (non-Javadoc)
* @see org.apache.geode.cache.Region#getRegionService()
* @see org.apache.geode.cache.Cache#getCacheTransactionManager()
*/
boolean isTransactionPresent(Region<?, ?> region) {
return region.getRegionService() instanceof Cache
&& isTransactionPresent(((Cache) region.getRegionService()).getCacheTransactionManager());
}
/*
* (non-Javadoc)
* @see org.apache.geode.cache.CacheTransactionManager#exists()
*/
boolean isTransactionPresent(CacheTransactionManager cacheTransactionManager) {
return cacheTransactionManager != null && cacheTransactionManager.exists();
}
/* (non-Javadoc) */
<K> void doRegionClear(Region<K, ?> region) {
region.removeAll(region.keySet());
}
/*
* (non-Javadoc)
* @see org.springframework.data.repository.CrudRepository#deleteAll()
*/
@Override
public void deleteAll() {
this.template.execute((GemfireCallback<Void>) region -> {

View File

@@ -13,12 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.gemfire.repository.support;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
@@ -28,6 +25,8 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import java.util.ArrayList;
@@ -44,41 +43,43 @@ import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.apache.geode.cache.Cache;
import org.apache.geode.cache.CacheTransactionManager;
import org.apache.geode.cache.DataPolicy;
import org.apache.geode.cache.Region;
import org.apache.geode.cache.RegionAttributes;
import org.apache.geode.cache.query.SelectResults;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.data.gemfire.GemfireTemplate;
import org.springframework.data.gemfire.repository.Wrapper;
import org.springframework.data.gemfire.repository.sample.Animal;
import org.springframework.data.repository.core.EntityInformation;
/**
* Unit tests for {@link SimpleGemfireRepository}.
* Unit Tests for {@link SimpleGemfireRepository}.
*
* @author John Blum
* @see org.junit.Rule
* @see java.util.function.Function
* @see java.util.stream.Stream
* @see org.junit.Test
* @see org.mockito.Mockito
* @see org.apache.geode.cache.Cache
* @see org.apache.geode.cache.Region
* @see org.springframework.data.gemfire.GemfireTemplate
* @see org.springframework.data.gemfire.repository.Wrapper
* @see org.springframework.data.gemfire.repository.support.SimpleGemfireRepository
* @see org.springframework.data.repository.core.EntityInformation
* @since 1.4.5
*/
@SuppressWarnings("unchecked")
@SuppressWarnings({ "rawtypes", "unchecked" })
public class SimpleGemfireRepositoryUnitTests {
@Rule
public ExpectedException exception = ExpectedException.none();
protected Map<Long, Animal> asMap(Iterable<Animal> animals) {
Map<Long, Animal> animalMap = new HashMap<>();
for (Animal animal : animals) {
@@ -89,14 +90,20 @@ public class SimpleGemfireRepositoryUnitTests {
}
protected Animal newAnimal(String name) {
Animal animal = new Animal();
animal.setName(name);
return animal;
}
protected Animal newAnimal(Long id, String name) {
Animal animal = newAnimal(name);
animal.setId(id);
return animal;
}
@@ -105,6 +112,7 @@ public class SimpleGemfireRepositoryUnitTests {
}
protected Cache mockCache(String name, boolean transactionExists) {
Cache mockCache = mock(Cache.class, String.format("%s.MockCache", name));
CacheTransactionManager mockCacheTransactionManager = mock(CacheTransactionManager.class,
@@ -117,13 +125,15 @@ public class SimpleGemfireRepositoryUnitTests {
}
protected EntityInformation<Animal, Long> mockEntityInformation() {
EntityInformation<Animal, Long> mockEntityInformation = mock(EntityInformation.class);
doAnswer(new Answer<Long>() {
private final AtomicLong idSequence = new AtomicLong(0L);
@Override
public Long answer(InvocationOnMock invocation) throws Throwable {
public Long answer(InvocationOnMock invocation) {
Animal argument = invocation.getArgument(0);
argument.setId(resolveId(argument.getId()));
return argument.getId();
@@ -133,6 +143,7 @@ public class SimpleGemfireRepositoryUnitTests {
private Long resolveId(Long id) {
return (id != null ? id : idSequence.incrementAndGet());
}
}).when(mockEntityInformation).getRequiredId(any(Animal.class));
return mockEntityInformation;
@@ -143,6 +154,7 @@ public class SimpleGemfireRepositoryUnitTests {
}
protected Region mockRegion(String name) {
Region mockRegion = mock(Region.class, String.format("%s.MockRegion", name));
when(mockRegion.getName()).thenReturn(name);
@@ -152,6 +164,7 @@ public class SimpleGemfireRepositoryUnitTests {
}
protected Region mockRegion(String name, Cache mockCache, DataPolicy dataPolicy) {
Region mockRegion = mockRegion(name);
when(mockRegion.getRegionService()).thenReturn(mockCache);
@@ -166,25 +179,54 @@ public class SimpleGemfireRepositoryUnitTests {
}
@Test
public void constructSimpleGemfireRepositoryWithNullTemplateThrowsIllegalArgumentException() {
exception.expect(IllegalArgumentException.class);
exception.expectCause(is(nullValue(Throwable.class)));
exception.expectMessage("Template must not be null");
public void constructsSimpleGemfireRepositorySuccessfully() {
new SimpleGemfireRepository<>(null, mockEntityInformation());
Region mockRegion = mock(Region.class);
GemfireTemplate template = spy(new GemfireTemplate(mockRegion));
EntityInformation mockEntityInformation = mock(EntityInformation.class);
SimpleGemfireRepository repository = new SimpleGemfireRepository(template, mockEntityInformation);
assertThat(repository).isNotNull();
verifyNoInteractions(template, mockRegion, mockEntityInformation);
}
@Test
public void constructSimpleGemfireRepositoryWithNullEntityInformationThrowsIllegalArgumentException() {
exception.expect(IllegalArgumentException.class);
exception.expectCause(is(nullValue(Throwable.class)));
exception.expectMessage("EntityInformation must not be null");
@Test(expected = IllegalArgumentException.class)
public void constructSimpleGemfireRepositoryWithNullTemplateThrowsIllegalArgumentException() {
new SimpleGemfireRepository<>(newGemfireTemplate(mockRegion()), null);
try {
new SimpleGemfireRepository<>(null, mockEntityInformation());
}
catch (IllegalArgumentException expected) {
assertThat(expected).hasMessage("GemfireTemplate must not be null");
assertThat(expected).hasNoCause();
throw expected;
}
}
@Test(expected = IllegalArgumentException.class)
public void constructSimpleGemfireRepositoryWithNullEntityInformationThrowsIllegalArgumentException() {
try {
new SimpleGemfireRepository<>(newGemfireTemplate(mockRegion()), null);
}
catch (IllegalArgumentException expected) {
assertThat(expected).hasMessage("EntityInformation must not be null");
assertThat(expected).hasNoCause();
throw expected;
}
}
@Test
public void saveEntityIsCorrect() {
Region<Long, Animal> mockRegion = mockRegion();
SimpleGemfireRepository<Animal, Long> repository =
@@ -237,21 +279,98 @@ public class SimpleGemfireRepositoryUnitTests {
@Test
public void countReturnsNumberOfRegionEntries() {
Region mockRegion = mockRegion("Example");
GemfireTemplate template = spy(newGemfireTemplate(mockRegion));
SelectResults mockSelectResults = mock(SelectResults.class);
Region mockRegion = mockRegion("Example");
GemfireTemplate template = spy(newGemfireTemplate(mockRegion));
doReturn(mockSelectResults).when(template).find(eq("SELECT count(*) FROM /Example"));
when(mockSelectResults.iterator()).thenReturn(Collections.singletonList(21).iterator());
SimpleGemfireRepository<Animal, Long> repository = new SimpleGemfireRepository<>(
template, mockEntityInformation());
SimpleGemfireRepository<Animal, Long> repository =
new SimpleGemfireRepository<>(template, mockEntityInformation());
assertThat(repository.count()).isEqualTo(21);
assertThat(repository).isNotNull();
assertThat(repository.count()).isEqualTo(21L);
verify(template, times(1)).getRegion();
verify(mockRegion, times(1)).getFullPath();
verify(template, times(1)).find(eq("SELECT count(*) FROM /Example"));
verify(mockSelectResults, times(1)).iterator();
verifyNoMoreInteractions(mockRegion, mockSelectResults, template);
}
@Test
public void countWhenSelectResultsAreNullIsNullSafeAndReturnsZero() {
Region mockRegion = mockRegion("Example");
GemfireTemplate template = spy(newGemfireTemplate(mockRegion));
doReturn(null).when(template).find(eq("SELECT count(*) FROM /Example"));
SimpleGemfireRepository<Animal, Long> repository =
new SimpleGemfireRepository<>(template, mockEntityInformation());
assertThat(repository).isNotNull();
assertThat(repository.count()).isEqualTo(0L);
verify(template, times(1)).getRegion();
verify(mockRegion, times(1)).getFullPath();
verify(template, times(1)).find(eq("SELECT count(*) FROM /Example"));
verifyNoMoreInteractions(mockRegion, template);
}
@Test
public void countWhenSelectResultsIteratorIsNullIsNullSafeAndReturnsZero() {
SelectResults mockSelectResults = mock(SelectResults.class);
Region mockRegion = mockRegion("Example");
GemfireTemplate template = spy(newGemfireTemplate(mockRegion));
doReturn(mockSelectResults).when(template).find(eq("SELECT count(*) FROM /Example"));
doReturn(null).when(mockSelectResults).iterator();
SimpleGemfireRepository<Animal, Long> repository =
new SimpleGemfireRepository<>(template, mockEntityInformation());
assertThat(repository).isNotNull();
assertThat(repository.count()).isEqualTo(0L);
verify(template, times(1)).getRegion();
verify(mockRegion, times(1)).getFullPath();
verify(template, times(1)).find(eq("SELECT count(*) FROM /Example"));
verify(mockSelectResults, times(1)).iterator();
verifyNoMoreInteractions(mockRegion, mockSelectResults, template);
}
@Test
public void countWhenSelectResultsIteratorIsEmptyReturnsZero() {
SelectResults mockSelectResults = mock(SelectResults.class);
Region mockRegion = mockRegion("Example");
GemfireTemplate template = spy(newGemfireTemplate(mockRegion));
doReturn(mockSelectResults).when(template).find(eq("SELECT count(*) FROM /Example"));
doReturn(Collections.emptyIterator()).when(mockSelectResults).iterator();
SimpleGemfireRepository<Animal, Long> repository =
new SimpleGemfireRepository<>(template, mockEntityInformation());
assertThat(repository).isNotNull();
assertThat(repository.count()).isEqualTo(0L);
verify(template, times(1)).getRegion();
verify(mockRegion, times(1)).getFullPath();
verify(template, times(1)).find(eq("SELECT count(*) FROM /Example"));
verify(mockSelectResults, times(1)).iterator();
verifyNoMoreInteractions(mockRegion, mockSelectResults, template);
}
@Test