Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NR-370072: Unit tests for application error reporting #381

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public interface SecurityIntrospector {
Set getResponseWriterHash();

Set getResponseOutStreamHash();

Log4JStrSubstitutor getLog4JStrSubstitutor();

SecurityMetaData getSecurityMetaData();
Expand All @@ -50,4 +51,6 @@ public interface SecurityIntrospector {
void clear();

int getRandomPort();

List<Exception> getApplicationRuntimeError();
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ public void clear() {
NewRelicSecurity.getAgent().getSecurityMetaData().clearCustomAttr();
NewRelicSecurity.getAgent().getSecurityMetaData().addCustomAttribute(Agent.OPERATIONS, new ArrayList<>());
NewRelicSecurity.getAgent().getSecurityMetaData().addCustomAttribute(Agent.EXIT_OPERATIONS, new ArrayList<>());
NewRelicSecurity.getAgent().getSecurityMetaData().addCustomAttribute(Agent.APPLICATION_RUNTIME_ERROR, new ArrayList<>());

SecurityMetaData meta = NewRelicSecurity.getAgent().getSecurityMetaData();
meta.setRequest(new HttpRequest());
Expand All @@ -158,4 +159,9 @@ public int getRandomPort() {
}
return port;
}

@Override
public List<Exception> getApplicationRuntimeError() {
return (List<Exception>) NewRelicSecurity.getAgent().getSecurityMetaData().getCustomAttribute(Agent.APPLICATION_RUNTIME_ERROR, List.class);
}
}
4 changes: 2 additions & 2 deletions instrumentation-security/akka-http-2.11_10.0.0/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ dependencies {
implementation("com.newrelic.agent.java:newrelic-api:${nrAPIVersion}")
implementation("com.newrelic.agent.java:newrelic-weaver-api:${nrAPIVersion}")
implementation("com.typesafe.akka:akka-http_2.12:10.0.0")
implementation("com.typesafe.akka:akka-stream_2.12:2.5.19")
implementation("com.typesafe.akka:akka-actor_2.12:2.5.19")
implementation("com.typesafe.akka:akka-stream_2.12:2.4.20")
implementation("com.typesafe.akka:akka-actor_2.12:2.4.20")
}

verifyInstrumentation {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package com.nr.agent.security.instrumentation.akka.http.core_10

import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.{ContentTypes, HttpEntity, HttpHeader, HttpRequest, HttpResponse}
import akka.stream.ActorMaterializer
import akka.util.ByteString
import com.newrelic.agent.security.introspec.{InstrumentationTestConfig, SecurityInstrumentationTestRunner, SecurityIntrospector}
import com.newrelic.api.agent.Trace
import com.newrelic.api.agent.security.instrumentation.helpers.{GenericHelper, ServletHelper}
import com.newrelic.api.agent.security.schema.operation.RXSSOperation
import com.newrelic.api.agent.security.schema.{SecurityMetaData, VulnerabilityCaseType}
import com.newrelic.security.test.marker.{Java11IncompatibleTest, Java17IncompatibleTest}
import org.junit.experimental.categories.Category
import org.junit.runner.RunWith
import org.junit.runners.MethodSorters
import org.junit.{Assert, FixMethodOrder, Test}

import java.util.UUID
import scala.collection.JavaConverters
import scala.concurrent.Await
import scala.concurrent.duration.DurationInt

@RunWith(classOf[SecurityInstrumentationTestRunner])
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
@InstrumentationTestConfig(includePrefixes = Array("akka", "scala"))
@Category(Array(classOf[Java11IncompatibleTest], classOf[Java17IncompatibleTest]))
class AkkaHttpCoreTest {

implicit val system: ActorSystem = ActorSystem()
implicit val materializer: ActorMaterializer = ActorMaterializer()

val akkaServer = new AkkaServer()
var port: Int = SecurityInstrumentationTestRunner.getIntrospector.getRandomPort
val baseUrl: String = "http://localhost:%s/".format(port)

val contentType: String = "text/plain"
val responseBody: String = "Hello, World!"
val requestBody: String = "Hurray!"

@Test
def syncHandlerAkkaServerTestWithAkkaServer(): Unit = {
val headerValue = String.valueOf(UUID.randomUUID)
val introspector: SecurityIntrospector = SecurityInstrumentationTestRunner.getIntrospector

val httpResponse = makeHttpRequest(headerValue)

// assertions
Assert.assertTrue("No operations detected", introspector.getOperations.size() > 0)
assertCSECHeaders( headers= httpResponse.headers, headerVal = headerValue)
val operations = introspector.getOperations
for (op <- JavaConverters.collectionAsScalaIterable(operations)){
op match {
case operation: RXSSOperation => assertRXSSOperation(operation)
case _ =>
}
}
assertMetaData(introspector.getSecurityMetaData)
}

@Trace(dispatcher = true, nameTransaction = true)
private def makeHttpRequest(header: String): HttpResponse = {
// start akka server & make request
akkaServer.start(port)

val response = Await.result(
Http().singleRequest(
HttpRequest(uri = baseUrl + header,
entity = HttpEntity.Strict.apply(ContentTypes.`text/plain(UTF-8)`, ByteString.fromString(requestBody)))),
new DurationInt(15).seconds)

akkaServer.stop()
response
}

private def assertCSECHeaders(headers: Seq[HttpHeader], headerVal: String): Unit = {
Assert.assertTrue(
String.format("%s CSEC header should be present", ServletHelper.CSEC_IAST_FUZZ_REQUEST_ID),
headers.exists(header => header.name().contains(ServletHelper.CSEC_IAST_FUZZ_REQUEST_ID))
)
Assert.assertTrue(
String.format("Invalid CSEC header value for: %s", ServletHelper.CSEC_IAST_FUZZ_REQUEST_ID),
headers.exists(header => header.value().contains(headerVal))
)

Assert.assertTrue(
String.format("%s CSEC header should be present", ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER),
headers.exists(header => header.name().contains(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER))
)
Assert.assertTrue(
String.format("Invalid CSEC header value for: %s", ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER),
headers.exists(header => header.value().contains(headerVal))
)

Assert.assertTrue(
String.format("%s CSEC header should be present", GenericHelper.CSEC_PARENT_ID),
headers.exists(header => header.name().contains(GenericHelper.CSEC_PARENT_ID))
)
Assert.assertTrue(
String.format("Invalid CSEC header value for: %s", GenericHelper.CSEC_PARENT_ID),
headers.exists(header => header.value().contains(headerVal))
)
}

private def assertRXSSOperation(operation: RXSSOperation): Unit = {
Assert.assertFalse("operation should not be empty", operation.isEmpty)
Assert.assertFalse("LowSeverityHook should be disabled", operation.isLowSeverityHook)
Assert.assertEquals("Invalid event category.", VulnerabilityCaseType.REFLECTED_XSS, operation.getCaseType)
Assert.assertEquals("Invalid executed method name.", "apply", operation.getMethodName)

Assert.assertFalse("request should not be empty", operation.getRequest.isEmpty)
Assert.assertEquals("Invalid response content-type.", operation.getRequest.getContentType, contentType)
Assert.assertEquals("Invalid responseBody.", operation.getRequest.getBody.toString, requestBody)
Assert.assertEquals("Invalid protocol.", operation.getRequest.getProtocol, "http")

Assert.assertFalse("response should not be empty", operation.getResponse.isEmpty)
Assert.assertEquals("Invalid response content-type.", operation.getResponse.getResponseContentType, contentType)
Assert.assertEquals("Invalid responseBody.", operation.getResponse.getResponseBody.toString, responseBody)
}

private def assertMetaData(metaData: SecurityMetaData): Unit = {
Assert.assertFalse("response should not be empty", metaData.getResponse.isEmpty)
Assert.assertEquals("Invalid response content-type.", metaData.getRequest.getContentType, contentType)
Assert.assertEquals("In valid responseBody.", metaData.getRequest.getBody.toString, requestBody)
Assert.assertFalse("response should not be empty", metaData.getRequest.isEmpty)
Assert.assertEquals("Invalid response content-type.", metaData.getResponse.getResponseContentType, contentType)
Assert.assertEquals("Invalid responseBody.", metaData.getResponse.getResponseBody.toString, responseBody)
Assert.assertEquals("Invalid protocol.", metaData.getRequest.getProtocol, "http")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.nr.agent.security.instrumentation.akka.http.core_10

import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.RawHeader
import akka.http.scaladsl.server.{Directives, RequestContext}
import akka.stream.ActorMaterializer
import akka.util.Timeout
import com.newrelic.api.agent.security.instrumentation.helpers.{GenericHelper, ServletHelper}

import scala.concurrent.duration._
import scala.concurrent.{ExecutionContextExecutor, Future}
import scala.language.postfixOps

//how the akka http core docs' example sets up a server
class AkkaServer extends Directives {
implicit val system: ActorSystem = ActorSystem()
implicit val executor: ExecutionContextExecutor = system.dispatcher
implicit val materializer: ActorMaterializer = ActorMaterializer()
implicit val timeout: Timeout = 3 seconds

var serverSource: Future[Http.ServerBinding] = _
var bindingFuture: Future[Http.ServerBinding] = _

def start(port: Int): Unit = {
val route = path(Segment) { segment =>
get { ctx: RequestContext =>
ctx.complete(
HttpResponse.apply(entity = "Hello, World!",
headers = (scala.collection.immutable.Seq(RawHeader.apply(ServletHelper.CSEC_IAST_FUZZ_REQUEST_ID, segment),
RawHeader.apply(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER, segment),
RawHeader.apply(GenericHelper.CSEC_PARENT_ID, segment)))))
}
}
serverSource = Http().bindAndHandle(route, interface = "localhost", port)
}

def stop(): Unit = {
if (serverSource != null) {
serverSource.flatMap(_.unbind()).onComplete(_ => system.terminate())
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package com.nr.agent.security.instrumentation.exception;

import com.newrelic.agent.security.introspec.InstrumentationTestConfig;
import com.newrelic.agent.security.introspec.SecurityInstrumentationTestRunner;
import com.newrelic.agent.security.introspec.SecurityIntrospector;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;

@RunWith(SecurityInstrumentationTestRunner.class)
@InstrumentationTestConfig(includePrefixes = "java.lang.")
public class ExceptionTest {

@Test
// In this case single uncaughtException is invoked and therefore single Application Runtime error will be reported
public void testReportApplicationRuntimeError() {
Exception e = new Exception();
Thread.UncaughtExceptionHandler uncaughtException = new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
e.printStackTrace();
}
};
uncaughtException.uncaughtException(Thread.currentThread(), e);

SecurityIntrospector introspector = SecurityInstrumentationTestRunner.getIntrospector();
Assert.assertFalse(introspector.getApplicationRuntimeError().isEmpty());
Assert.assertEquals(introspector.getApplicationRuntimeError().get(0), e);
}

@Test
// In this case no uncaughtException is invoked and no Application Runtime error will be reported
public void testReportNoApplicationRuntimeError() {
Exception e = new Exception();
Thread.UncaughtExceptionHandler uncaughtException = new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
e.printStackTrace();
}
};

SecurityIntrospector introspector = SecurityInstrumentationTestRunner.getIntrospector();
Assert.assertTrue(introspector.getApplicationRuntimeError().isEmpty());
}

@Test
// In this case multiple uncaughtExceptions are invoked and therefore multiple Application Runtime errors will be reported
public void testReportApplicationRuntimeErrors() {
Exception e = new Exception();

Thread.UncaughtExceptionHandler uncaughtException = new Thread.UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
e.printStackTrace();
}
};
uncaughtException.uncaughtException(Thread.currentThread(), e);

Exception e1 = new Exception();
uncaughtException.uncaughtException(Thread.currentThread(), e1);

SecurityIntrospector introspector = SecurityInstrumentationTestRunner.getIntrospector();
Assert.assertFalse(introspector.getApplicationRuntimeError().isEmpty());
Assert.assertEquals(introspector.getApplicationRuntimeError().get(0), e);
Assert.assertEquals(introspector.getApplicationRuntimeError().get(1), e1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ public class Agent implements SecurityAgent {

public static final String OPERATIONS = "operations";
public static final String EXIT_OPERATIONS = "exit-operations";
public static final String APPLICATION_RUNTIME_ERROR = "application-runtime-error";
private static Agent instance;
private final IastDetectionCategory defaultIastDetectionCategory = new IastDetectionCategory();

private AgentPolicy policy = new AgentPolicy();

private static final Object lock = new Object();

private Map<Integer, SecurityMetaData> securityMetaDataMap = new ConcurrentHashMap<>();
private final Map<Integer, SecurityMetaData> securityMetaDataMap = new ConcurrentHashMap<>();

private java.net.URL agentJarURL;

Expand Down Expand Up @@ -115,6 +116,7 @@ public SecurityMetaData getSecurityMetaData() {
meta = new SecurityMetaData();
meta.addCustomAttribute(OPERATIONS, new ArrayList<AbstractOperation>());
meta.addCustomAttribute(EXIT_OPERATIONS, new ArrayList<AbstractOperation>());
meta.addCustomAttribute(APPLICATION_RUNTIME_ERROR, new ArrayList<Exception>());
securityMetaDataMap.put(tx.hashCode(), meta);
}
populateSecurityData(meta);
Expand Down Expand Up @@ -211,7 +213,10 @@ public String decryptAndVerify(String encryptedData, String hashVerifier) {

@Override
public void reportApplicationRuntimeError(SecurityMetaData securityMetaData, Throwable exception) {

if (this.getSecurityMetaData().getCustomAttribute(APPLICATION_RUNTIME_ERROR, List.class).contains(exception)) {
return;
}
this.getSecurityMetaData().getCustomAttribute(APPLICATION_RUNTIME_ERROR, List.class).add(exception);
}

@Override
Expand Down
Loading