Creating a Custom JUnit5 Extension: A Comprehensive Guide
JUnit 5 offers extensions, a powerful mechanism for extending the behavior of JUnit tests. In this blog post, we’ll look at how JUnit 5 extensions work and walk through the process of creating custom JUnit 5 extensions.
What is a JUnit5 Extension?
In JUnit 5, an extension is any class that implements one or more of the extension interfaces provided by JUnit Jupiter; all of these interfaces extend the `Extension` marker interface. The extension model replaces JUnit 4’s runners and rules. It lets developers write plugins that hook into the execution of test classes and methods and customize their behavior.
To create an extension, you implement one or more of these extension APIs. Examples include `BeforeAllCallback`, `BeforeEachCallback`, `AfterEachCallback`, `AfterAllCallback`, `TestExecutionExceptionHandler`, and `ParameterResolver`.
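To make the list above more concrete, here is a minimal sketch of a `TestExecutionExceptionHandler`, an API this post does not otherwise use. It assumes JUnit Jupiter 5.x on the test classpath; the class names `IgnoreIOExceptionExtension` and `IgnoreIOExceptionDemo` are hypothetical. The handler swallows an `IOException` thrown by a test and rethrows anything else:

```java
import java.io.IOException;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.TestExecutionExceptionHandler;

// Hypothetical extension: an IOException thrown by a test is treated as handled,
// so the test is reported as successful; any other exception is rethrown.
class IgnoreIOExceptionExtension implements TestExecutionExceptionHandler {

    @Override
    public void handleTestExecutionException(ExtensionContext context, Throwable throwable)
            throws Throwable {
        if (throwable instanceof IOException) {
            return; // swallow: the test does not fail
        }
        throw throwable; // everything else still fails the test
    }
}

// Hypothetical test class that registers the extension declaratively.
@ExtendWith(IgnoreIOExceptionExtension.class)
class IgnoreIOExceptionDemo {

    @Test
    void ioFailureIsIgnored() throws IOException {
        throw new IOException("simulated network hiccup");
    }

    @Test
    void otherFailuresStillFail() {
        throw new IllegalStateException("this one is reported");
    }
}
```

With this extension registered, `ioFailureIsIgnored` is reported as successful and `otherFailuresStillFail` as failed. Because swallowing exceptions can hide real problems, a handler like this is best kept narrow.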
Step 1: Setting up the Extension
Let’s create a simple JUnit 5 extension to understand the lifecycle of a custom extension.
Write the extension class
First, we define our extension class. It implements the `BeforeAllCallback`, `AfterAllCallback`, `BeforeEachCallback`, and `AfterEachCallback` interfaces to hook into the test execution lifecycle. For simplicity, we’ll name the class `SampleCustomExtension`:
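```java
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;

public class SampleCustomExtension implements BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback {
    // callback methods are implemented in Step 2
}
```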
Step 2: Implementing Callbacks
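Next, we’ll implement the callback methods. Besides the four interfaces from Step 1, the extension below also implements `BeforeTestExecutionCallback` and `AfterTestExecutionCallback`, whose role is explained later in this post:
- beforeAll – called once before all tests in the class
- afterAll – called once after all tests in the class have finished
- beforeEach – called before each test, ahead of the test class’s `@BeforeEach` methods
- afterEach – called after each test, after the test class’s `@AfterEach` methods
- beforeTestExecution – called immediately before each test method runs
- afterTestExecution – called immediately after each test method runs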
```java
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SampleCustomExtension implements BeforeAllCallback, AfterAllCallback,
        BeforeEachCallback, AfterEachCallback, BeforeTestExecutionCallback, AfterTestExecutionCallback {

    private static final Logger logger = LoggerFactory.getLogger(SampleCustomExtension.class);

    @Override
    public void afterAll(ExtensionContext extensionContext) throws Exception {
        logger.info("After all tests");
    }

    @Override
    public void beforeAll(ExtensionContext extensionContext) throws Exception {
        logger.info("Before all tests");
    }

    @Override
    public void afterEach(ExtensionContext extensionContext) throws Exception {
        logger.info("After each test");
    }

    @Override
    public void beforeEach(ExtensionContext extensionContext) throws Exception {
        logger.info("Before each test");
    }

    @Override
    public void afterTestExecution(ExtensionContext extensionContext) throws Exception {
        logger.info("After each test - afterTestExecution");
    }

    @Override
    public void beforeTestExecution(ExtensionContext extensionContext) throws Exception {
        logger.info("Before each test - beforeTestExecution");
    }
}
```
Now write a test class that registers the extension with the `@ExtendWith` annotation:
```java
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ExtendWith(SampleCustomExtension.class)
public class SampleTest {

    private static final Logger logger = LoggerFactory.getLogger(SampleTest.class);

    @BeforeEach
    public void beforeEach() {
        logger.info("before each test in test class");
    }

    @AfterEach
    public void afterEach() {
        logger.info("after each test in test class");
    }

    @Test
    public void sampleTest() {
        logger.info("in sample test");
    }

    @Test
    public void sampleTest2() {
        logger.info("in sample test - 2");
    }
}
```
If you run the above test class, you will see output like the following in the logs. JUnit’s default method order is deterministic but intentionally nonobvious, so the tests do not necessarily run in declaration order; here `sampleTest2` ran first.
```text
[INFO] Running dev.fullstackcode.currencyexchange.SampleTest
INFO dev.fullstackcode.currencyexchange.junit.extension.SampleCustomExtension -- Before all tests
INFO dev.fullstackcode.currencyexchange.junit.extension.SampleCustomExtension -- Before each test
INFO dev.fullstackcode.currencyexchange.SampleTest -- before each test in test class
INFO dev.fullstackcode.currencyexchange.junit.extension.SampleCustomExtension -- Before each test - beforeTestExecution
INFO dev.fullstackcode.currencyexchange.SampleTest -- in sample test - 2
INFO dev.fullstackcode.currencyexchange.junit.extension.SampleCustomExtension -- After each test - afterTestExecution
INFO dev.fullstackcode.currencyexchange.SampleTest -- after each test in test class
INFO dev.fullstackcode.currencyexchange.junit.extension.SampleCustomExtension -- After each test
INFO dev.fullstackcode.currencyexchange.junit.extension.SampleCustomExtension -- Before each test
INFO dev.fullstackcode.currencyexchange.SampleTest -- before each test in test class
INFO dev.fullstackcode.currencyexchange.junit.extension.SampleCustomExtension -- Before each test - beforeTestExecution
INFO dev.fullstackcode.currencyexchange.SampleTest -- in sample test
INFO dev.fullstackcode.currencyexchange.junit.extension.SampleCustomExtension -- After each test - afterTestExecution
INFO dev.fullstackcode.currencyexchange.SampleTest -- after each test in test class
INFO dev.fullstackcode.currencyexchange.junit.extension.SampleCustomExtension -- After each test
INFO dev.fullstackcode.currencyexchange.junit.extension.SampleCustomExtension -- After all tests
```
Difference between BeforeEachCallback and BeforeTestExecutionCallback
If you look at the logs above, two extension callbacks run before each test method executes: `BeforeEachCallback` and `BeforeTestExecutionCallback`. Likewise, two callbacks run after it: `AfterTestExecutionCallback` and `AfterEachCallback`.
Although both pairs run around every test, they wrap different parts of it:
- If you need callbacks that are invoked around the `@BeforeEach` and `@AfterEach` methods, implement `BeforeEachCallback` and `AfterEachCallback`. `beforeEach` runs before any `@BeforeEach` method, and `afterEach` runs after all `@AfterEach` methods.
- If you need callbacks that are invoked immediately around the `@Test` method itself, implement `BeforeTestExecutionCallback` and `AfterTestExecutionCallback`. They run after the `@BeforeEach` methods and before the `@AfterEach` methods.
Because the test-execution callbacks run immediately before and immediately after the test method, they are well suited for timing, tracing, and similar use cases.
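The ordering described above can be checked with a small sketch. It assumes JUnit Jupiter 5.x; `OrderRecordingExtension`, `OrderDemo`, and the shared `EVENTS` list are hypothetical names used only for illustration. Each callback and each lifecycle method appends a label to the list:

```java
import java.util.ArrayList;
import java.util.List;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.ExtensionContext;

// Hypothetical extension that records when each callback fires.
class OrderRecordingExtension implements BeforeEachCallback, BeforeTestExecutionCallback,
        AfterTestExecutionCallback, AfterEachCallback {

    // Static event log so the order can be inspected after the run (demo only).
    static final List<String> EVENTS = new ArrayList<>();

    @Override
    public void beforeEach(ExtensionContext context) {
        EVENTS.add("BeforeEachCallback");
    }

    @Override
    public void beforeTestExecution(ExtensionContext context) {
        EVENTS.add("BeforeTestExecutionCallback");
    }

    @Override
    public void afterTestExecution(ExtensionContext context) {
        EVENTS.add("AfterTestExecutionCallback");
    }

    @Override
    public void afterEach(ExtensionContext context) {
        EVENTS.add("AfterEachCallback");
    }
}

@ExtendWith(OrderRecordingExtension.class)
class OrderDemo {

    @BeforeEach
    void setUp() {
        OrderRecordingExtension.EVENTS.add("@BeforeEach");
    }

    @Test
    void test() {
        OrderRecordingExtension.EVENTS.add("@Test");
    }

    @AfterEach
    void tearDown() {
        OrderRecordingExtension.EVENTS.add("@AfterEach");
    }
}
```

Running `OrderDemo` once leaves `EVENTS` in the order `BeforeEachCallback`, `@BeforeEach`, `BeforeTestExecutionCallback`, `@Test`, `AfterTestExecutionCallback`, `@AfterEach`, `AfterEachCallback`, matching the log output earlier in this post. The static list is only for demonstration and is not safe for parallel test execution.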
Building JUnit5 Extension to Log Method Execution Timing
Let’s create a simple JUnit5 extension that will log the execution time of each test method.
Interface Implementation
First, we need to implement the `BeforeTestExecutionCallback` and `AfterTestExecutionCallback` interfaces to hook into the test execution lifecycle:
```java
public class TimingExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback {
}
```
State Storage
Next, we need to store the start time of the test. We’ll use the `Store` in the `ExtensionContext` to hold this:
```java
@Override
public void beforeTestExecution(ExtensionContext context) {
    getStore(context).put(context.getRequiredTestMethod(), System.currentTimeMillis());
}

private Store getStore(ExtensionContext context) {
    return context.getStore(ExtensionContext.Namespace.create(getClass(), context));
}
```
Logging the Execution Time
Then, we’ll calculate and log the execution time in the `afterTestExecution` callback:
```java
@Override
public void afterTestExecution(ExtensionContext context) {
    long startTime = getStore(context).remove(context.getRequiredTestMethod(), long.class);
    long duration = System.currentTimeMillis() - startTime;
    System.out.println("Test method [" + context.getRequiredTestMethod() + "] took " + duration + " ms.");
}
```
Putting it together, the complete extension looks like this:
```java
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Store;

public class TimingExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback {

    @Override
    public void beforeTestExecution(ExtensionContext context) {
        // Remember the start time, keyed by the test method
        getStore(context).put(context.getRequiredTestMethod(), System.currentTimeMillis());
    }

    private Store getStore(ExtensionContext context) {
        return context.getStore(ExtensionContext.Namespace.create(getClass(), context));
    }

    @Override
    public void afterTestExecution(ExtensionContext context) throws Exception {
        long startTime = getStore(context).remove(context.getRequiredTestMethod(), long.class);
        long duration = System.currentTimeMillis() - startTime;
        System.out.println("Test method [" + context.getRequiredTestMethod() + "] took " + duration + " ms.");
    }
}
```
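One possible refinement is sketched below, assuming JUnit Jupiter 5.x. It measures with `System.nanoTime()`, which is intended for elapsed-time measurement and, unlike `System.currentTimeMillis()`, is not affected by wall-clock adjustments, and it reports the result with `ExtensionContext.publishReportEntry` instead of `System.out`. The names `NanoTimingExtension`, `NanoTimingDemo`, and the `durationMs` key are hypothetical:

```java
import java.util.concurrent.TimeUnit;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
import org.junit.jupiter.api.extension.ExtensionContext.Store;

// Hypothetical variant of TimingExtension.
class NanoTimingExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback {

    private static final String START_TIME = "startNanos";

    @Override
    public void beforeTestExecution(ExtensionContext context) {
        // nanoTime is a monotonic clock meant for measuring elapsed time
        getStore(context).put(START_TIME, System.nanoTime());
    }

    @Override
    public void afterTestExecution(ExtensionContext context) {
        long start = getStore(context).remove(START_TIME, long.class);
        long millis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
        // Hand the measurement to JUnit's reporting instead of printing it
        context.publishReportEntry("durationMs", String.valueOf(millis));
    }

    private Store getStore(ExtensionContext context) {
        // One namespace per test method keeps entries of different tests apart
        return context.getStore(Namespace.create(getClass(), context.getRequiredTestMethod()));
    }
}

@ExtendWith(NanoTimingExtension.class)
class NanoTimingDemo {

    @Test
    void sleepsBriefly() throws InterruptedException {
        Thread.sleep(20);
    }
}
```

Report entries are delivered to the registered test execution listeners, so whether and how they appear (console, IDE, build reports) depends on the tool running the tests.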
Using the Extension
To use the extension, we simply annotate our test class or method with `@ExtendWith(TimingExtension.class)`:
```java
@ExtendWith(TimingExtension.class)
public class ProgrammaticWireMockTest {
    // your tests here
}
```
In the test logs, you can see statements like the following, showing each test method’s execution time:
```text
Test method [public void dev.fullstackcode.currencyexchange.ProgrammaticWireMockTest.testConvertByCurrencyCode() throws java.lang.Exception] took 329 ms.
Test method [public void dev.fullstackcode.currencyexchange.ProgrammaticWireMockTest.testConvertByCurrencyCodeWhenConversionFoundWithGivenCurrencyCode() throws java.lang.Exception] took 17 ms.
Test method [void dev.fullstackcode.currencyexchange.ProgrammaticWireMockTest.contextLoads()] took 0 ms.
Test method [public void dev.fullstackcode.currencyexchange.ProgrammaticWireMockTest.testCurrencyConversionByCountryWhenNoCountryFoundWithGivenCode() throws java.lang.Exception] took 8 ms.
```
Sharing the code
In general, to share code between test classes, we put it in a base class and extend that class in every test class.
This approach has a few disadvantages:
- The code in the base class is shared with every subclass, even when a test class does not need that logic.
- Java supports only single inheritance, so a test class that extends the base class cannot extend any other class.
Extensions act as a centralized place to share such code without relying on inheritance.
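As a sketch of that idea, assuming JUnit Jupiter 5.x: JUnit also discovers `@ExtendWith` when it is used as a meta-annotation, so a composed annotation can bundle shared extensions, and any test class can opt in with a single annotation while remaining free to extend whatever class it needs. `CountingExtension`, `@Counted`, and the two test classes are hypothetical names:

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.ExtensionContext;

// Hypothetical shared extension: counts how many tests it has wrapped.
class CountingExtension implements BeforeEachCallback {
    static final AtomicInteger TESTS_SEEN = new AtomicInteger();

    @Override
    public void beforeEach(ExtensionContext context) {
        TESTS_SEEN.incrementAndGet();
    }
}

// Hypothetical composed annotation: JUnit finds the @ExtendWith on it,
// so annotating a class with @Counted registers CountingExtension.
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@ExtendWith(CountingExtension.class)
@interface Counted {
}

@Counted
class FirstCountedTest {
    @Test
    void one() {
    }
}

@Counted
class SecondCountedTest {
    @Test
    void two() {
    }
}
```

Both test classes pick up `CountingExtension` through `@Counted`, so after running them together `TESTS_SEEN` is 2.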
Let’s take an example where our service depends on two external services, a currency API and a country API. In the tests, we need to run mock servers for these services that return canned responses.
Writing an extension to start the WireMock servers
Let’s write an extension that starts and stops the WireMock servers and can inject them into test methods.
```java
import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;

public class WiremockServerExtension implements BeforeAllCallback, AfterAllCallback, ParameterResolver {

    // Mock of the currency conversion API
    private static final WireMockServer currencyServer =
            new WireMockServer(WireMockConfiguration.options().port(4141));
    // Mock of the country API
    private static final WireMockServer countryServer =
            new WireMockServer(WireMockConfiguration.options().port(4040));

    @Override
    public void afterAll(ExtensionContext extensionContext) throws Exception {
        if (currencyServer.isRunning()) {
            currencyServer.stop();
        }
        if (countryServer.isRunning()) {
            countryServer.stop();
        }
    }

    @Override
    public void beforeAll(ExtensionContext extensionContext) throws Exception {
        if (!currencyServer.isRunning()) {
            currencyServer.start();
        }
        if (!countryServer.isRunning()) {
            countryServer.start();
        }
    }

    @Override
    public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException {
        // Only claim WireMockServer parameters whose name this extension can resolve
        String name = parameterContext.getParameter().getName();
        return parameterContext.getParameter().getType().equals(WireMockServer.class)
                && (name.equals("currencyServer") || name.equals("countryServer"));
    }

    @Override
    public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException {
        // Source-level parameter names are only available when compiled with -parameters
        String name = parameterContext.getParameter().getName();
        if (name.equals("currencyServer")) {
            return currencyServer;
        }
        if (name.equals("countryServer")) {
            return countryServer;
        }
        throw new ParameterResolutionException("No WireMockServer registered for parameter '" + name + "'");
    }

    public static WireMockServer currencyWireMock() {
        return currencyServer;
    }

    public static WireMockServer countryWireMock() {
        return countryServer;
    }
}
```
Then we can use the WireMock server extension in a test class like this:
```java
import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.WireMock;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.springframework.test.context.DynamicPropertySource;
import org.springframework.test.web.servlet.MockMvc;

import static com.github.tomakehurst.wiremock.client.WireMock.urlMatching;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

@SpringBootTest
@AutoConfigureMockMvc
@ExtendWith(TimingExtension.class)
@ExtendWith(WiremockServerExtension.class)
public class ProgrammaticWireMockExtensionTest {

    @Autowired
    private MockMvc mockMvc;

    @Test
    void contextLoads() {
    }

    @Test
    public void testCurrencyConversionByCountry(WireMockServer currencyServer, WireMockServer countryServer) throws Exception {
        String currencyCode = "jpy";
        String currencyConversionUrl = String.format("/gh/fawazahmed0/currency-api@1/latest/currencies/usd/%s.json", currencyCode.toLowerCase());
        String country = "japan";
        String countryurl = String.format("/v3.1/name/%s", country.toLowerCase());
        // Canned response of the country API
        String response = """
[{"name":{"common":"Japan","official":"Japan","nativeName":{"jpn":{"official":"日本","common":"日本"}}},"tld":[".jp",".みんな"],"cca2":"JP","ccn3":"392","cca3":"JPN","cioc":"JPN","independent":true,"status":"officially-assigned","unMember":true,"currencies":{"JPY":{"name":"Japanese yen","symbol":"¥"}},"idd":{"root":"+8","suffixes":["1"]},"capital":["Tokyo"],"altSpellings":["JP","Nippon","Nihon"],"region":"Asia","subregion":"Eastern Asia","languages":{"jpn":"Japanese"},"translations":{"ara":{"official":"اليابان","common":"اليابان"},"bre":{"official":"Japan","common":"Japan"},"ces":{"official":"Japonsko","common":"Japonsko"},"cym":{"official":"Japan","common":"Japan"},"deu":{"official":"Japan","common":"Japan"},"est":{"official":"Jaapan","common":"Jaapan"},"fin":{"official":"Japani","common":"Japani"},"fra":{"official":"Japon","common":"Japon"},"hrv":{"official":"Japan","common":"Japan"},"hun":{"official":"Japán","common":"Japán"},"ita":{"official":"Giappone","common":"Giappone"},"jpn":{"official":"日本","common":"日本"},"kor":{"official":"일본국","common":"일본"},"nld":{"official":"Japan","common":"Japan"},"per":{"official":"ژاپن","common":"ژاپن"},"pol":{"official":"Japonia","common":"Japonia"},"por":{"official":"Japão","common":"Japão"},"rus":{"official":"Япония","common":"Япония"},"slk":{"official":"Japonsko","common":"Japonsko"},"spa":{"official":"Japón","common":"Japón"},"srp":{"official":"Јапан","common":"Јапан"},"swe":{"official":"Japan","common":"Japan"},"tur":{"official":"Japonya","common":"Japonya"},"urd":{"official":"جاپان","common":"جاپان"},"zho":{"official":"日本国","common":"日本"}},"latlng":[36.0,138.0],"landlocked":false,"area":377930.0,"demonyms":{"eng":{"f":"Japanese","m":"Japanese"},"fra":{"f":"Japonaise","m":"Japonais"}},"flag":"\\uD83C\\uDDEF\\uD83C\\uDDF5","maps":{"googleMaps":"https://goo.gl/maps/NGTLSCSrA8bMrvnX9","openStreetMaps":"https://www.openstreetmap.org/relation/382313"},"population":125836021,"gini":{"2013":32.9},"fifa":"JPN","car":{"signs":["J"],"side":"left"},"timezones":["UTC+09:00"],"continents":["Asia"],"flags":{"png":"https://flagcdn.com/w320/jp.png","svg":"https://flagcdn.com/jp.svg","alt":"The flag of Japan features a crimson-red circle at the center of a white field."},"coatOfArms":{"png":"https://mainfacts.com/media/images/coats_of_arms/jp.png","svg":"https://mainfacts.com/media/images/coats_of_arms/jp.svg"},"startOfWeek":"monday","capitalInfo":{"latlng":[35.68,139.75]},"postalCode":{"format":"###-####","regex":"^(\\\\d{7})$"}}]""";
        // Canned response of the currency API
        String response1 = """
{
"date": "2023-05-19",
"jpy": 138.544529
}""";
        currencyServer.stubFor(WireMock.get(urlMatching(currencyConversionUrl)).willReturn(WireMock.aResponse().withStatus(200).withBody(response1)));
        countryServer.stubFor(WireMock.get(urlMatching(countryurl)).willReturn(WireMock.aResponse().withStatus(200).withBody(response)));
        this.mockMvc.perform(get("/currencyCodeVersion/country/{countryCode}", country))
                .andExpect(status().isOk());
    }

    // Point the application's HTTP clients at the mock servers
    @DynamicPropertySource
    public static void properties(DynamicPropertyRegistry registry) {
        registry.add("country.url", () -> String.format("%s/v3.1/name/%s",
                WiremockServerExtension.countryWireMock().baseUrl(), "%s"));
        registry.add("currencyconverter.url", () -> String.format("%s/gh/fawazahmed0/currency-api@1" +
                "/latest/currencies/usd/%s.json", WiremockServerExtension.currencyWireMock().baseUrl(), "%s"));
    }
}
```
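In the extension above, I have implemented the `ParameterResolver` interface.
`ParameterResolver` defines the API for extensions that wish to dynamically resolve arguments for parameters at runtime. If a test class constructor or a `@Test`, `@BeforeEach`, `@AfterEach`, `@BeforeAll`, or `@AfterAll` method declares a parameter, an argument for that parameter must be resolved at runtime by a `ParameterResolver`.
In the `resolveParameter` method, we resolve each parameter based on its type and name. Note that `Parameter.getName()` returns the names declared in the source code only when the classes are compiled with the `javac -parameters` flag; otherwise it returns synthetic names such as `arg0`. The Spring Boot Maven parent and Gradle plugin enable this flag by default.
`ParameterResolver` lets test methods receive objects managed by the extension without accessing the extension’s fields directly.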
In the `ProgrammaticWireMockExtensionTest` class shown earlier, the test method receives the currency and country servers simply by declaring them as parameters, without calling the extension’s static accessors.
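Because name-based resolution depends on the `-parameters` flag, an alternative is to select the instance with a qualifier annotation, which `ParameterContext.isAnnotated` and `ParameterContext.findAnnotation` make straightforward. The sketch below assumes JUnit Jupiter 5.x; `@Endpoint`, `EndpointResolver`, `EndpointDemo`, and the two URIs are hypothetical, and it injects plain `java.net.URI` values instead of `WireMockServer` instances so that it runs without WireMock:

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.net.URI;
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;

// Hypothetical qualifier naming which endpoint should be injected.
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@interface Endpoint {
    String value();
}

// Hypothetical resolver keyed by the qualifier instead of the parameter name.
class EndpointResolver implements ParameterResolver {

    // Stand-ins for the base URLs of the two mock servers.
    private static final Map<String, URI> ENDPOINTS = Map.of(
            "currency", URI.create("http://localhost:4141"),
            "country", URI.create("http://localhost:4040"));

    @Override
    public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
        return parameterContext.getParameter().getType() == URI.class
                && parameterContext.isAnnotated(Endpoint.class);
    }

    @Override
    public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
        String name = parameterContext.findAnnotation(Endpoint.class)
                .map(Endpoint::value)
                .orElseThrow(); // present, because supportsParameter checked it
        URI uri = ENDPOINTS.get(name);
        if (uri == null) {
            throw new ParameterResolutionException("No endpoint named '" + name + "'");
        }
        return uri;
    }
}

@ExtendWith(EndpointResolver.class)
class EndpointDemo {

    static URI seenCurrency;
    static URI seenCountry;

    @Test
    void receivesBothEndpoints(@Endpoint("currency") URI currency, @Endpoint("country") URI country) {
        seenCurrency = currency;
        seenCountry = country;
    }
}
```

The same pattern could be applied to the WireMock extension by resolving `WireMockServer` parameters annotated with a qualifier such as `@Endpoint("currency")`.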
Note: WireMock ships with its own JUnit 5 extension (`WireMockExtension`); the example above builds a custom extension to show how such an extension works.
Using Extensions in Integration Tests
Extensions are not limited to unit tests; you can also use them in integration tests.
The following code shows a Spring Boot integration test that uses the Testcontainers library, with the Testcontainers setup moved into an extension.
```java
public class WiremockTestContainersExtension implements BeforeAllCallback, AfterAllCallback {

    static final WireMockContainer WIRE_MOCK_CONTAINER = new WireMockContainer("wiremock/wiremock:2.35.0-alpine")
            .withCopyFileToContainer(MountableFile.forClasspathResource("extensions"), "/var/wiremock/extensions")
            .withStubMappingForClasspathResource("stubs") // loads all *.json files in resources/stubs/
            // withCommand replaces any previously set command, so all arguments go into a single call
            .withCommand("--verbose",
                    "--global-response-templating",
                    "--extensions",
                    "dev.fullstackcode.wiremock.transformer.ResponseTemplateTransformerExtension,com.ninecookies.wiremock.extensions.JsonBodyTransformer");

    @Override
    public void afterAll(ExtensionContext extensionContext) throws Exception {
        WIRE_MOCK_CONTAINER.stop();
    }

    @Override
    public void beforeAll(ExtensionContext extensionContext) throws Exception {
        WIRE_MOCK_CONTAINER.start();
    }
}
```
Using the extension
```java
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@Slf4j
@ExtendWith(WiremockTestContainersExtension.class)
class WiremockExample22IT {

    @Autowired
    private TestRestTemplate testRestTemplate;

    @LocalServerPort
    private int serverPort;

    private static final String ROOT_URL = "http://localhost:";

    @Test
    void contextLoads() {
    }

    @Test
    public void testCreatePost() throws Exception {
        Post p = new Post();
        p.setBody("body");
        p.setTitle("Title");
        ResponseEntity<Post> response = testRestTemplate.postForEntity(ROOT_URL + serverPort + "/posts", p, Post.class);
        assertEquals(200, response.getStatusCode().value());
        assertNotNull(response.getBody().getId());
        assertEquals(p.getBody(), response.getBody().getBody());
        assertEquals(p.getTitle(), response.getBody().getTitle());
    }

    @DynamicPropertySource
    public static void properties(DynamicPropertyRegistry registry) {
        registry.add("json.mock.api",
                () -> WiremockTestContainersExtension.WIRE_MOCK_CONTAINER.getHttpUrl() + "/posts");
    }
}
```
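The extension above starts and stops the container for every test class that uses it. When several test classes share one container, a common JUnit pattern is to keep the resource in the root context’s `Store`, which JUnit closes once at the end of the test run. Below is a minimal sketch assuming JUnit Jupiter 5.x; `FakeServer` is a hypothetical stand-in for a container, and `SharedServerExtension` and the two test classes are illustrative names:

```java
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;

// Hypothetical stand-in for an expensive resource such as a container.
class FakeServer implements ExtensionContext.Store.CloseableResource {
    static final AtomicInteger STARTS = new AtomicInteger();
    static final AtomicInteger STOPS = new AtomicInteger();

    FakeServer() {
        STARTS.incrementAndGet(); // a real resource would start here, e.g. container.start()
    }

    @Override
    public void close() {
        STOPS.incrementAndGet(); // called by JUnit when the root store is closed
    }
}

// Creates the resource on first use and shares it across all test classes.
class SharedServerExtension implements BeforeAllCallback {

    @Override
    public void beforeAll(ExtensionContext context) {
        context.getRoot()
                .getStore(Namespace.GLOBAL)
                .getOrComputeIfAbsent("fakeServer", key -> new FakeServer(), FakeServer.class);
    }
}

@ExtendWith(SharedServerExtension.class)
class FirstSharedServerTest {
    @Test
    void usesServer() {
    }
}

@ExtendWith(SharedServerExtension.class)
class SecondSharedServerTest {
    @Test
    void usesServerToo() {
    }
}
```

Running both test classes in one run starts `FakeServer` once and closes it once. Newer JUnit releases may prefer plain `AutoCloseable` over `CloseableResource` for stored values, so check the documentation of the version you use.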
Conclusion
In this blog post, we’ve explored JUnit 5 extensions and built extensions ranging from a simple lifecycle logger to ones that manage WireMock servers and Testcontainers. We’ve seen how extensions can hook into the test lifecycle, inject parameters, and manage resources.
Extensions in JUnit 5 provide a powerful and flexible way to extend the behavior of your tests. They enable you to encapsulate setup and teardown logic in reusable components, and keep your test classes clean and focused on testing.