diff --git a/import-control.xml b/import-control.xml new file mode 100644 index 00000000..a94b64ac --- /dev/null +++ b/import-control.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/main/resources/default_config.xml b/src/main/resources/default_config.xml index bc7c73f5..647736d2 100644 --- a/src/main/resources/default_config.xml +++ b/src/main/resources/default_config.xml @@ -309,4 +309,9 @@ + + + + + diff --git a/src/main/resources/import_control_1_4.dtd b/src/main/resources/import_control_1_4.dtd new file mode 100644 index 00000000..5d6b20e5 --- /dev/null +++ b/src/main/resources/import_control_1_4.dtd @@ -0,0 +1,107 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/main/resources/reference.conf b/src/main/resources/reference.conf index 67f8446c..25040bd3 100644 --- a/src/main/resources/reference.conf +++ b/src/main/resources/reference.conf @@ -411,3 +411,9 @@ disallow.case.brace.description = "Checks that braces aren't used in case clause throw.message = "Avoid using throw statements." throw.label = "No throw statements." throw.description = "Checks that throw is not used." + +import.control.message = "Import not allowed: ''{0}'' by ''{1}'' in package {2}" +import.control.label = "Control imports per file or package" +import.control.description = "Control imports per file or package" +import.control.file.label = "File configuring imports per file or package" +import.control.file.description = "File configuring imports per file or package" \ No newline at end of file diff --git a/src/main/resources/scalastyle_definition.xml b/src/main/resources/scalastyle_definition.xml index e95a9ac5..b1ba12a2 100644 --- a/src/main/resources/scalastyle_definition.xml +++ b/src/main/resources/scalastyle_definition.xml @@ -218,4 +218,9 @@ + + + + + diff --git a/src/main/resources/scalastyle_documentation.xml b/src/main/resources/scalastyle_documentation.xml index cb5ca6ea..c0553a26 100644 --- a/src/main/resources/scalastyle_documentation.xml +++ b/src/main/resources/scalastyle_documentation.xml @@ -905,4 +905,19 @@ To fix it, replace the (unicode operator) `⇒` with `=>`. ]]> + + + Allow fine-grained control over imports in files and packages. + + The import control specification is the exact same as Checkstyle's Import Control checker. See https://checkstyle.sourceforge.io/config_imports.html#ImportControl for more information. + + + + import-control.xml + + + ]]> + + diff --git a/src/main/scala/org/scalastyle/scalariform/ImportsChecker.scala b/src/main/scala/org/scalastyle/scalariform/ImportsChecker.scala index 036a535c..d2ca356e 100644 --- a/src/main/scala/org/scalastyle/scalariform/ImportsChecker.scala +++ b/src/main/scala/org/scalastyle/scalariform/ImportsChecker.scala @@ -16,26 +16,27 @@ package org.scalastyle.scalariform -import java.util.regex.Pattern +import _root_.scalariform.lexer.Tokens.{DOT, PACKAGE, VARID} +import _root_.scalariform.lexer.{MultiLineComment, Token, Whitespace} +import _root_.scalariform.parser.{ + AstNode, + BlockImportExpr, + CompilationUnit, + Expr, + ExprElement, + GeneralTokens, + ImportClause, + ImportSelectors +} +import org.scalastyle.scalariform.VisitorHelper.visit +import org.scalastyle.{FileSpec, Level, Lines, Message, PositionError, ScalariformChecker, ScalastyleError} +import java.util.regex.Pattern +import scala.annotation.tailrec import scala.collection.mutable.ListBuffer +import scala.util.Try import scala.util.matching.Regex - -import _root_.scalariform.lexer.MultiLineComment -import _root_.scalariform.lexer.Token -import _root_.scalariform.lexer.Whitespace -import _root_.scalariform.parser.AstNode -import _root_.scalariform.parser.BlockImportExpr -import _root_.scalariform.parser.CompilationUnit -import _root_.scalariform.parser.Expr -import _root_.scalariform.parser.ExprElement -import _root_.scalariform.parser.GeneralTokens -import _root_.scalariform.parser.ImportClause -import _root_.scalariform.parser.ImportSelectors -import org.scalastyle.PositionError -import org.scalastyle.ScalariformChecker -import org.scalastyle.ScalastyleError -import org.scalastyle.scalariform.VisitorHelper.visit +import scala.xml.{Attribute, Elem, InputSource, Node, Source, XML} // scalastyle:off multiple.string.literals @@ -422,3 +423,495 @@ class ImportOrderChecker extends ScalariformChecker { PositionError(offset, args.map(_.toString).toList, Some(this.errorKey + "." + errorKey)) } + +/** + * Style checker that controls what can be imported in each package. The following configuration parameters + * are available: + * + * - file: a reference to the import-control configuration + * + * Currently, this checker only looks at the top-level list of imports. + * + * Note that the file format is defined in a dtd (see 'import_control_1_4.dtd' in resources) and is the same + * as the Checkstyle ImportControl configuration. As such the same configuration can be used for Checkstyle + * and Scalastyle (useful for mixed projects). See + * https://checkstyle.sourceforge.io/config_imports.html#ImportControl for more details. + */ +class ImportControlChecker extends ScalariformChecker { + + import ImportControlConfig._ + + val errorKey: String = "import.control" + // split in part before ., (optional) middle part with trailing dot, last part after dot + private val importDecomposeRe = "(.*)\\.(.*\\.)*(.*)".r + private var config: Try[ImportControlConfig] = _ + + override def setParameters(parameters: Map[String, String]): Unit = { + super.setParameters(parameters) + + config = Try { + require(parameters.contains("file") || parameters.contains("inline")) + + val source = + parameters + .get("file") + .map { fileName => + val file = new java.io.File(fileName) + require(file.exists(), s"File '${file.getAbsolutePath}' does not exist.") + Source.fromFile(file) + } + .getOrElse(Source.fromString(parameters("inline"))) + apply(source) + } + } + + override def verify(ast: CompilationUnit): List[ScalastyleError] = + throw new UnsupportedOperationException("Wrong method called") + + override def verify[T <: FileSpec]( + file: T, + level: Level, + ast: CompilationUnit, + lines: Lines + ): List[Message[T]] = + verify(file.name, ast).map(p => toStyleError(file, p, level, lines)) + + def verify(fileName: String, ast: CompilationUnit): List[ScalastyleError] = { + val name = extractFileName(fileName) + + val packageName = getPackageName(ast) + + // Use config.get here, because we want to be able to see mistakes in configuration + val ruleSet = config.get.determineRulesSet(packageName, name) + + val CompilationUnit(statements, _) = ast + statements.immediateChildren + .flatMap { + case ImportClause(_, BlockImportExpr(prefix, ImportSelectors(_, first, others, _)), _, _) => + val basePackage = exprToText(prefix.contents) + (first :: others.map(_._2)).flatMap { + case s: Expr => Some((s.firstToken.offset, s"$basePackage${s.contents.head.tokens.head.text}")) + case _ => None + } + case n @ ImportClause(_, Expr(contents), _, _) => Some((n.firstToken.offset, exprToText(contents))) + case _ => Nil + } + .foldLeft((Map.empty[String, String], List.empty[PositionError])) { + case ( + (scope, errors), + (offset, importText @ importDecomposeRe(firstPart, middlePart, importedItem)) + ) => + // scope is necessary to allow situations like: + // import java.util + // import util.Collection + // In that case first `java.util` is validated and then `java.util.Collection` + val fullImport = + scope + .get(firstPart) + .map(result => s"$result.${Option(middlePart).getOrElse("")}$importedItem") + .getOrElse(importText) + val newScope = if (importedItem == "_") scope else scope + (importedItem -> fullImport) + val result = ruleSet.validateImport(fullImport) + val maybeError = + Some(PositionError(offset, args = List(fullImport, result.rule, result.packageDescription))) + .filterNot(_ => result.allowed) + (newScope, errors ::: maybeError.toList) + case (result, _) => result // invalid import; ignore + } + ._2 + } + + private def extractFileName(fileName: String): String = { + val fileNameWithoutExtension = { + val lastDotPosition = fileName.lastIndexOf('.') + if (lastDotPosition >= 0) fileName.take(lastDotPosition) + else fileName + } + fileNameWithoutExtension.drop( + fileNameWithoutExtension.lastIndexOf(java.io.File.separatorChar.toInt) + 1 + ) + } + + private def exprToText(contents: List[ExprElement]): String = + contents + .flatMap { + case GeneralTokens(toks) => toks.map(_.text) + case n: Any => throw new IllegalStateException(s"FIXME: unexpected expr child node $n") + } + .mkString("") + + private def getPackageName(ast: CompilationUnit): String = { + def isPartOfPackageName(t: Token): Boolean = (t.tokenType == DOT) || (t.tokenType == VARID) + + @annotation.tailrec + def getNextPackageName(tokens: List[Token]): (List[Token], List[Token]) = tokens match { + case Nil => (Nil, Nil) + case hd :: tail if hd.tokenType == PACKAGE => tail.span(isPartOfPackageName) + case l: Any => getNextPackageName(l.dropWhile(tok => tok.tokenType != PACKAGE)) + } + + @annotation.tailrec + def getPackageNameLoop(tokens: List[Token], myAccumulator: List[List[Token]]): List[List[Token]] = + getNextPackageName(tokens) match { + case (Nil, Nil) => myAccumulator.reverse // Return the result, but reverse since we gathered backward + case (Nil, remainder) => + getPackageNameLoop(remainder, myAccumulator) // Found package object - try again + case (l, remainder) => // add match to results, go look again + val pkgName = l.filter(tok => tok.tokenType != DOT) // Strip out the dots between varids + getPackageNameLoop(remainder, pkgName :: myAccumulator) + } + + val packageNames = getPackageNameLoop(ast.tokens, Nil) + packageNames.flatten.map(_.text).mkString(".") + } +} + +object ImportControlConfig { + import ImportRule._ + import SourceIdentifier._ + import StrategyOnMismatch._ + + def apply(source: InputSource): ImportControlConfig = { + val xml = XML.load(source) + readImportControlConfig(xml) match { + case Right(config) => config + case Left(msg) => throw new RuntimeException(s"Configuration problem found: $msg") + } + } + + private def readImportControlConfig(elem: Elem): Either[String, ImportControlConfig] = { + for { + pkg <- elem.attribute("pkg").map(_.text).toRight("'pkg' attribute is required on 'import-control' node") + _ <- validAttributes(elem, "pkg", "strategyOnMismatch", "regex") + strategyOnMismatch <- + readStrategyOnMismatch(elem, "disallowed") + .filterOrElse( + _ != DelegateToParent, + "'DelegateToParent' not allowed as 'strategyOnMismatch' for 'import-control' node" + ) + regex <- Right[String, Boolean](readBoolean(elem, "regex")) + rules <- readImportRules(elem) + sources <- readSources(elem) + } yield ImportControlConfig(pkg, strategyOnMismatch, regex, rules, sources) + } + + private def validAttributes(node: Node, attributes: String*): Either[String, Unit] = { + node.attributes.collect { + case attr: Attribute if !attributes.contains(attr.key) => attr.key + } match { + case Nil => Right(()) + case head :: Nil => Left(s"Attribute '$head' is not valid on node '${node.label}'") + case list => + Left(s"Attributes ${list.map(a => s"'$a'").mkString(", ")} are not valid on node '${node.label}'") + } + } + + private def readImportRules(node: Node): Either[String, List[ImportRule]] = { + def readImportRule(node: Node): Either[String, ImportRule] = { + val isPackageRule = node.attribute("pkg").isDefined + for { + name <- node + .attribute("pkg") + .orElse(node.attribute("class")) + .map(_.text) + .toRight(s"'pkg' or 'class' attribute required on '${node.label}' node") + _ <- + if (isPackageRule) validAttributes(node, "pkg", "exact-match", "local-only", "regex") + else validAttributes(node, "class", "local-only", "regex") + } yield { + val exactMatch = readBoolean(node, "exact-match") + val localOnly = readBoolean(node, "local-only") + val regex = readBoolean(node, "regex") + (node.label, isPackageRule) match { + case ("allow", true) => AllowPackage(name, exactMatch, localOnly, regex) + case ("allow", false) => AllowClass(name, localOnly, regex) + case ("disallow", true) => DisallowPackage(name, exactMatch, localOnly, regex) + case ("disallow", false) => DisallowClass(name, localOnly, regex) + case _ => throw new RuntimeException("Impossible option") + } + } + } + + node.child + .filter(c => c.label == "allow" || c.label == "disallow") + .map(readImportRule) + .toSeq + .sequence + } + + private def readSources(node: Node): Either[String, List[SourceIdentifier]] = { + def readSource(node: Node): Either[String, SourceIdentifier] = + for { + name <- node + .attribute("name") + .map(_.text) + .toRight(s"'name' attribute required on '${node.label}' node") + _ <- + if (node.label == "file") validAttributes(node, "name", "regex") + else validAttributes(node, "name", "strategyOnMismatch", "regex") + strategyOnMismatch <- readStrategyOnMismatch(node, "delegateToParent") + rules <- readImportRules(node) + sources <- readSources(node) + } yield { + val regex = readBoolean(node, "regex") + node.label match { + case "subpackage" => Subpackage(name, strategyOnMismatch, regex, rules, sources) + case "file" => File(name, regex, rules) + } + } + + // first process file sources to have file matches precede subpackage sources + ((node \ "file" map readSource) ++ (node \ "subpackage" map readSource)).sequence + } + + private def readBoolean(node: Node, label: String): Boolean = + node + .attribute(label) + .map(_.text) + .collect { + case "true" => true + case "false" => false + } + .getOrElse(true) + + implicit class SeqOfEitherOp[A, B](eithers: Seq[Either[A, B]]) { + def sequence: Either[A, List[B]] = + eithers.foldRight[Either[A, List[B]]](Right(List.empty[B])) { + case (Right(b), acc) => acc.map(b :: _) + case (Left(a), _) => Left(a) + } + } + + implicit class EitherOps[A, B](either: Either[A, B]) { + def map[C](fn: B => C): Either[A, C] = either.right.map(fn) + def flatMap[C >: A, D](fn: B => Either[C, D]): Either[C, D] = either.right.flatMap(fn) + def filterOrElse(predicate: B => Boolean, or: => A): Either[A, B] = either match { + case Right(b) if predicate(b) => Right(b) + case Right(_) => Left(or) + case Left(a) => Left(a) + } + } + + private def readStrategyOnMismatch(node: Node, default: String): Either[String, StrategyOnMismatch] = + node.attribute("strategyOnMismatch").map(_.text).getOrElse(default) match { + case "delegateToParent" => Right(DelegateToParent) + case "allowed" => Right(Allowed) + case "disallowed" => Right(Disallowed) + case other: String => + Left(s"Unsupported value '$other' for 'strategyOnMismatch' attribute on '${node.label}' node") + } + + sealed trait RuleContainer { + def rules: List[ImportRule] + + def strategyOnMismatch: StrategyOnMismatch + + override def toString: String = this match { + case s: ImportControlConfig => s"[root] '${s.pkg}'" + case s: Subpackage => s"[sub] '${s.name}'" + case s: SourceIdentifier.File => s"[file] '${s.name}'" + } + } + + case class MatchResult(ruleContainer: RuleContainer, exactMatch: Boolean) + + case class ImportValidationResult(rule: String, packageDescription: String, allowed: Boolean) + + case class RuleSet(results: List[MatchResult]) { + def validateImport(forImport: String): ImportValidationResult = { + case class ImportValidationResultInternal(rule: ImportRule, allowed: Boolean) + def checkAccess(result: MatchResult): Option[ImportValidationResultInternal] = { + result.ruleContainer.rules + .filter(rule => result.exactMatch || !rule.localOnly) + .flatMap(rule => rule.verifyImport(forImport).map(ImportValidationResultInternal(rule, _))) + .headOption + } + + def describeLocation(remainingResults: List[MatchResult]): String = + remainingResults.reverse.map(_.ruleContainer.toString).mkString("/") + + @tailrec + def loop(remainingResults: List[MatchResult]): ImportValidationResult = + remainingResults match { + case Nil => ImportValidationResult("no match with root package", "n/a", allowed = false) + case head :: tail => + checkAccess(head) match { + case Some(ImportValidationResultInternal(rule, allowed)) => + ImportValidationResult(rule.toString, describeLocation(remainingResults), allowed) + case None => + head.ruleContainer.strategyOnMismatch match { + case DelegateToParent => loop(tail) + case Disallowed => + ImportValidationResult( + "mismatch strategy", + describeLocation(remainingResults), + allowed = false + ) + case Allowed => + ImportValidationResult( + "mismatch strategy", + describeLocation(remainingResults), + allowed = true + ) + } + } + } + + loop(results) + } + + } + + sealed abstract class PackageMatcher( + name: String, + regex: Boolean, + private[scalariform] val childSources: List[SourceIdentifier] + ) extends RuleContainer { + private val nameRegex = { + val regexBase = if (regex) name else Regex.quote(name) + (s"(?:$regexBase)(?:\\.(.*))?").r + } + + def locateRules(packageName: String, fileName: String): Option[List[MatchResult]] = + packageName match { + case nameRegex(rest) => + val subpackage = Option(rest).getOrElse("") + val thisResult = MatchResult(this, rest == null) + Some( + childSources + .map(_.locateRules(subpackage, fileName)) + .collectFirst { case Some(list) => thisResult :: list } + .getOrElse(List(thisResult)) + ) + case _ => None + } + } + + object SourceIdentifier { + sealed trait SourceIdentifier extends RuleContainer { + def locateRules(packageName: String, fileName: String): Option[List[MatchResult]] + } + + case class Subpackage( + name: String, + strategyOnMismatch: StrategyOnMismatch = DelegateToParent, + regex: Boolean = true, + rules: List[ImportRule], + sourceIdentifiers: List[SourceIdentifier] + ) extends PackageMatcher(name, regex, sourceIdentifiers) + with SourceIdentifier + + case class File(name: String, regex: Boolean = true, rules: List[ImportRule]) extends SourceIdentifier { + private val nameRegex = { + val regexBase = if (regex) name else Regex.quote(name) + (s"(?:$regexBase)").r + } + + override val strategyOnMismatch: StrategyOnMismatch = DelegateToParent + + override def locateRules(packageName: String, fileName: String): Option[List[MatchResult]] = + Option(fileName) + .filter(_ => packageName.isEmpty) + .filter { + case nameRegex() => true + case _ => false + } + .map(_ => List(MatchResult(this, exactMatch = true))) + } + } + + object ImportRule { + sealed trait ImportRule { + val name: String + + val localOnly: Boolean + + val exactMatch: Boolean = true + + val regex: Boolean + + def verifyImport(forImport: String): Option[Boolean] + + override def toString: String = this match { + case rule: AllowPackage => s"allow rule for package '${rule.name}'" + case rule: DisallowPackage => s"disallow rule for package '${rule.name}'" + case rule: AllowClass => s"allow rule for class '${rule.name}'" + case rule: DisallowClass => s"disallow rule for class '${rule.name}'" + } + } + + sealed protected abstract class PackageImportRule(allow: Boolean) extends ImportRule { + private val nameRegex = { + val regexBase = if (regex) name else Regex.quote(name) + (s"(?:$regexBase)(?:\\..*)?(\\..*)?").r + } + + private def importMatch(forImport: String): Boolean = + forImport match { + case nameRegex(lastSegment) => !exactMatch || lastSegment == null + case _ => false + } + + override def verifyImport(forImport: String): Option[Boolean] = + if (importMatch(forImport)) Some(allow) else None + } + + sealed protected abstract class ClassImportRule(allow: Boolean) extends ImportRule { + private val nameRegex = { + val regexBase = if (regex) name else Regex.quote(name) + (s"(?:$regexBase)").r + } + + private def importMatch(forImport: String): Boolean = + forImport match { + case nameRegex() => true + case _ => false + } + + override def verifyImport(forImport: String): Option[Boolean] = + if (importMatch(forImport)) Some(allow) else None + } + + case class AllowPackage( + name: String, + override val exactMatch: Boolean = true, + localOnly: Boolean = true, + regex: Boolean = true + ) extends PackageImportRule(allow = true) + + case class AllowClass(name: String, localOnly: Boolean = true, regex: Boolean = true) + extends ClassImportRule(allow = true) + + case class DisallowPackage( + name: String, + override val exactMatch: Boolean = true, + localOnly: Boolean = true, + regex: Boolean = true + ) extends PackageImportRule(allow = false) + + case class DisallowClass(name: String, localOnly: Boolean = true, regex: Boolean = true) + extends ClassImportRule(allow = false) + } + + object StrategyOnMismatch { + sealed trait StrategyOnMismatch + + case object DelegateToParent extends StrategyOnMismatch + + case object Allowed extends StrategyOnMismatch + + case object Disallowed extends StrategyOnMismatch + } + + case class ImportControlConfig( + pkg: String, + strategyOnMismatch: StrategyOnMismatch = Disallowed, + regex: Boolean = true, + rules: List[ImportRule], + sourceIdentifiers: List[SourceIdentifier] + ) extends PackageMatcher(pkg, regex, sourceIdentifiers) { + + def determineRulesSet(packageName: String, fileName: String): RuleSet = + RuleSet(locateRules(packageName, fileName).getOrElse(Nil).reverse) + } +} diff --git a/src/test/scala/org/scalastyle/file/CheckerTest.scala b/src/test/scala/org/scalastyle/file/CheckerTest.scala index 49957782..a607dffd 100644 --- a/src/test/scala/org/scalastyle/file/CheckerTest.scala +++ b/src/test/scala/org/scalastyle/file/CheckerTest.scala @@ -32,9 +32,7 @@ trait CheckerTest { protected val key: String protected val classUnderTest: Class[_ <: Checker[_]] - object NullFileSpec extends FileSpec { - def name(): String = "" - } + case class CustomFileSpec(name: String) extends FileSpec protected def assertErrors[T <: FileSpec]( expected: List[Message[T]], @@ -42,7 +40,8 @@ trait CheckerTest { params: Map[String, String] = Map(), customMessage: Option[String] = None, commentFilter: Boolean = true, - customId: Option[String] = None + customId: Option[String] = None, + customFileName: String = "" ) = { val classes = List( ConfigurationChecker(classUnderTest.getName(), WarningLevel, true, params, customMessage, customId) @@ -50,22 +49,38 @@ trait CheckerTest { val configuration = ScalastyleConfiguration("", commentFilter, classes) assertEquals( expected.mkString("\n"), - new CheckerUtils().verifySource(configuration, classes, NullFileSpec, source).mkString("\n") + new CheckerUtils() + .verifySource(configuration, classes, CustomFileSpec(customFileName), source) + .mkString("\n") ) } - protected def fileError(args: List[String] = List(), customMessage: Option[String] = None) = - StyleError(NullFileSpec, classUnderTest, key, WarningLevel, args, None, None, customMessage) - protected def lineError(line: Int, args: List[String] = List()) = - StyleError(NullFileSpec, classUnderTest, key, WarningLevel, args, Some(line), None) + protected def fileError( + args: List[String] = List(), + customMessage: Option[String] = None, + customFileName: String = "" + ) = + StyleError( + CustomFileSpec(customFileName), + classUnderTest, + key, + WarningLevel, + args, + None, + None, + customMessage + ) + protected def lineError(line: Int, args: List[String] = List(), customFileName: String = "") = + StyleError(CustomFileSpec(customFileName), classUnderTest, key, WarningLevel, args, Some(line), None) protected def columnError( line: Int, column: Int, args: List[String] = List(), - errorKey: Option[String] = None + errorKey: Option[String] = None, + customFileName: String = "" ) = StyleError( - NullFileSpec, + CustomFileSpec(customFileName), classUnderTest, errorKey.getOrElse(key), WarningLevel, diff --git a/src/test/scala/org/scalastyle/scalariform/ImportsCheckerTest.scala b/src/test/scala/org/scalastyle/scalariform/ImportsCheckerTest.scala index 7413f27b..bf8a39b6 100644 --- a/src/test/scala/org/scalastyle/scalariform/ImportsCheckerTest.scala +++ b/src/test/scala/org/scalastyle/scalariform/ImportsCheckerTest.scala @@ -20,6 +20,10 @@ import org.junit.Assert.assertTrue import org.junit.Test import org.scalastyle.file.CheckerTest import org.scalatestplus.junit.AssertionsForJUnit +import org.scalastyle.scalariform.ImportControlConfig.{MatchResult, PackageMatcher, RuleSet, SourceIdentifier} +import org.scalastyle.scalariform.ImportControlConfig.SourceIdentifier.{SourceIdentifier, Subpackage} + +import scala.xml.Source // scalastyle:off magic.number multiple.string.literals @@ -478,3 +482,179 @@ class ImportOrderCheckerTest extends AssertionsForJUnit with CheckerTest { private def errorKey(subkey: String): Option[String] = Some(key + "." + subkey) } + +class ImportControlCheckerTest extends AssertionsForJUnit with CheckerTest { + + import org.scalatest.matchers.should.Matchers._ + + val key = "import.control" + val classUnderTest = classOf[ImportControlChecker] + + @Test def testRuleSetMatching(): Unit = { + import TestTools.MatcherResultOps + val config = ImportControlConfig(Source.fromString(""" + | + | + | + | + | + | + | + | + | + |""".stripMargin)) + + val rootPartialMatch = MatchResult(config, exactMatch = false) + val rootExactMatch = MatchResult(config, exactMatch = true) + val exact = true + config.determineRulesSet("com.mismatch", "MyFile") shouldBe RuleSet(Nil) + config.determineRulesSet("com.test", "MyFile") shouldBe RuleSet( + List(rootPartialMatch.file("MyFile"), rootExactMatch) + ) + config.determineRulesSet("com.test", "OtherFile") shouldBe RuleSet(List(rootExactMatch)) + config.determineRulesSet("com.test.submismatch", "MyFile") shouldBe RuleSet(List(rootPartialMatch)) + config.determineRulesSet("com.test.subpackage", "SomeFile") shouldBe RuleSet( + List(rootPartialMatch.subpackage("subpack.*", exact), rootPartialMatch) + ) + config.determineRulesSet("com.test.subpackage", "MyFile") shouldBe RuleSet( + List( + rootPartialMatch.subpackage("subpack.*").file("MyFile"), + rootPartialMatch.subpackage("subpack.*", exact), + rootPartialMatch + ) + ) + config.determineRulesSet("com.test.subpackage.subsub", "SomeFile") shouldBe RuleSet( + List(rootPartialMatch.subpackage("subpack.*", exact), rootPartialMatch) + ) + config.determineRulesSet("com.testpackage", "MyFile") shouldBe RuleSet(Nil) + config.determineRulesSet("com.test.sub", "SomeFile") shouldBe RuleSet( + List(rootPartialMatch.subpackage("sub", exact), rootPartialMatch) + ) + config.determineRulesSet("com.test.sub", "MyFile") shouldBe RuleSet( + List( + rootPartialMatch.subpackage("sub").file("MyFile"), + rootPartialMatch.subpackage("sub", exact), + rootPartialMatch + ) + ) + } + + @Test def testImportMatching(): Unit = { + import TestTools.RuleSetOp + val config = ImportControlConfig( + Source.fromString(""" + | + | + | + | + | + | + | + | + | + | + | + | + | + | + | + | + | + |""".stripMargin) + ) + + config.determineRulesSet("com.mismatch", "MyFile").isAllowed("com.allowed.SomeClass") shouldBe false + config.determineRulesSet("com.test", "MyFile").isAllowed("com.allowed.SomeClass") shouldBe true + config.determineRulesSet("com.test", "MyFile").isAllowed("com.disallowed.SomeClass") shouldBe false + config.determineRulesSet("com.test", "MyFile").isAllowed("com.other.SomeClass") shouldBe true + config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.ext") shouldBe true + config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.ext.AClass") shouldBe true + config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.ext.sub.AClass") shouldBe true + config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.disallowed.AClass") shouldBe false + config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.disallowed.SomeClass") shouldBe true + config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.disallowed.OtherClass") shouldBe false + config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.allowed.SomeClass") shouldBe true + config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.whatever.SomeClass") shouldBe false + config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.other.allowed.SomeClass") shouldBe false + config.determineRulesSet("com.test", "MyFile").isAllowed("com.other.allowed.SomeClass") shouldBe true + config.determineRulesSet("com.test", "SomeFile").isAllowed("com.other.allowed.SomeClass") shouldBe true + config.determineRulesSet("com.test.subpack", "MyFile").isAllowed("com.whatever.SomeClass") shouldBe true + config + .determineRulesSet("com.test.subpack.sub", "MyFile") + .isAllowed("com.whatever.SomeClass") shouldBe true + config.determineRulesSet("com.test.testasub", "MyFile").isAllowed("com.allowed.SomeClass") shouldBe true + config.determineRulesSet("com.test.test.sub", "MyFile").isAllowed("com.allowed.SomeClass") shouldBe false + } + + @Test def testWithSource(): Unit = { + val source = + """package com + |package test + | + |import com.allowed.MyClass + |import com.allowed.{Class1, Class2 => Renamed} + |import com.allowed.sub._ + |import com.disallowed.SomeClass + |import com.disallowed.{Class1, Class2 => Renamed} + |import com.disallowed.sub._ + |import com.disallowed.SomeClass._ + |import com.allowed + |import allowed.MyOtherClass + |import com.disallowed + |import disallowed.MyOtherClass + |""".stripMargin + + val config = + """ + | + | + | + |""".stripMargin + + val expected = List( + columnError(7, 0, args = List("com.disallowed.SomeClass", "mismatch strategy", "[root] 'com.test'")), + columnError(8, 23, args = List("com.disallowed.Class1", "mismatch strategy", "[root] 'com.test'")), + columnError(8, 31, args = List("com.disallowed.Class2", "mismatch strategy", "[root] 'com.test'")), + columnError(9, 0, args = List("com.disallowed.sub._", "mismatch strategy", "[root] 'com.test'")), + columnError(10, 0, args = List("com.disallowed.SomeClass._", "mismatch strategy", "[root] 'com.test'")), + columnError(13, 0, args = List("com.disallowed", "mismatch strategy", "[root] 'com.test'")), + columnError(14, 0, args = List("com.disallowed.MyOtherClass", "mismatch strategy", "[root] 'com.test'")) + ) + + val sep = java.io.File.separator + assertErrors( + expected, + source, + params = Map("inline" -> config), + customFileName = s"com${sep}test${sep}Test.scala" + ) + } + + object TestTools { + implicit class MatcherResultOps(matchResult: MatchResult) { + private def childSources: List[SourceIdentifier] = + matchResult.ruleContainer match { + case packageMatcher: PackageMatcher => packageMatcher.childSources + case _ => Nil + } + + def subpackage(name: String, exactMatch: Boolean = false): MatchResult = + childSources + .collectFirst { + case s: Subpackage if s.name == name => MatchResult(s, exactMatch) + } + .getOrElse(fail(s"Can't find subpackage with name '$name'")) + + def file(name: String): MatchResult = + childSources + .collectFirst { + case f: SourceIdentifier.File if f.name == name => MatchResult(f, exactMatch = true) + } + .getOrElse(fail(s"Can't find file with name '$name'")) + } + + implicit class RuleSetOp(ruleSet: RuleSet) { + def isAllowed(forImport: String): Boolean = ruleSet.validateImport(forImport).allowed + } + } +}