diff --git a/import-control.xml b/import-control.xml
new file mode 100644
index 00000000..a94b64ac
--- /dev/null
+++ b/import-control.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/main/resources/default_config.xml b/src/main/resources/default_config.xml
index bc7c73f5..647736d2 100644
--- a/src/main/resources/default_config.xml
+++ b/src/main/resources/default_config.xml
@@ -309,4 +309,9 @@
+
+
+
+
+
diff --git a/src/main/resources/import_control_1_4.dtd b/src/main/resources/import_control_1_4.dtd
new file mode 100644
index 00000000..5d6b20e5
--- /dev/null
+++ b/src/main/resources/import_control_1_4.dtd
@@ -0,0 +1,107 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/main/resources/reference.conf b/src/main/resources/reference.conf
index 67f8446c..25040bd3 100644
--- a/src/main/resources/reference.conf
+++ b/src/main/resources/reference.conf
@@ -411,3 +411,9 @@ disallow.case.brace.description = "Checks that braces aren't used in case clause
throw.message = "Avoid using throw statements."
throw.label = "No throw statements."
throw.description = "Checks that throw is not used."
+
+import.control.message = "Import not allowed: ''{0}'' by ''{1}'' in package {2}"
+import.control.label = "Control imports per file or package"
+import.control.description = "Control imports per file or package"
+import.control.file.label = "File configuring imports per file or package"
+import.control.file.description = "File configuring imports per file or package"
\ No newline at end of file
diff --git a/src/main/resources/scalastyle_definition.xml b/src/main/resources/scalastyle_definition.xml
index e95a9ac5..b1ba12a2 100644
--- a/src/main/resources/scalastyle_definition.xml
+++ b/src/main/resources/scalastyle_definition.xml
@@ -218,4 +218,9 @@
+
+
+
+
+
diff --git a/src/main/resources/scalastyle_documentation.xml b/src/main/resources/scalastyle_documentation.xml
index cb5ca6ea..c0553a26 100644
--- a/src/main/resources/scalastyle_documentation.xml
+++ b/src/main/resources/scalastyle_documentation.xml
@@ -905,4 +905,19 @@ To fix it, replace the (unicode operator) `⇒` with `=>`.
]]>
+
+
+ Allow fine-grained control over imports in files and packages.
+
+ The import control specification is the exact same as Checkstyle's Import Control checker. See https://checkstyle.sourceforge.io/config_imports.html#ImportControl for more information.
+
+
+
+ import-control.xml
+
+
+ ]]>
+
+
diff --git a/src/main/scala/org/scalastyle/scalariform/ImportsChecker.scala b/src/main/scala/org/scalastyle/scalariform/ImportsChecker.scala
index 036a535c..d2ca356e 100644
--- a/src/main/scala/org/scalastyle/scalariform/ImportsChecker.scala
+++ b/src/main/scala/org/scalastyle/scalariform/ImportsChecker.scala
@@ -16,26 +16,27 @@
package org.scalastyle.scalariform
-import java.util.regex.Pattern
+import _root_.scalariform.lexer.Tokens.{DOT, PACKAGE, VARID}
+import _root_.scalariform.lexer.{MultiLineComment, Token, Whitespace}
+import _root_.scalariform.parser.{
+ AstNode,
+ BlockImportExpr,
+ CompilationUnit,
+ Expr,
+ ExprElement,
+ GeneralTokens,
+ ImportClause,
+ ImportSelectors
+}
+import org.scalastyle.scalariform.VisitorHelper.visit
+import org.scalastyle.{FileSpec, Level, Lines, Message, PositionError, ScalariformChecker, ScalastyleError}
+import java.util.regex.Pattern
+import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
+import scala.util.Try
import scala.util.matching.Regex
-
-import _root_.scalariform.lexer.MultiLineComment
-import _root_.scalariform.lexer.Token
-import _root_.scalariform.lexer.Whitespace
-import _root_.scalariform.parser.AstNode
-import _root_.scalariform.parser.BlockImportExpr
-import _root_.scalariform.parser.CompilationUnit
-import _root_.scalariform.parser.Expr
-import _root_.scalariform.parser.ExprElement
-import _root_.scalariform.parser.GeneralTokens
-import _root_.scalariform.parser.ImportClause
-import _root_.scalariform.parser.ImportSelectors
-import org.scalastyle.PositionError
-import org.scalastyle.ScalariformChecker
-import org.scalastyle.ScalastyleError
-import org.scalastyle.scalariform.VisitorHelper.visit
+import scala.xml.{Attribute, Elem, InputSource, Node, Source, XML}
// scalastyle:off multiple.string.literals
@@ -422,3 +423,495 @@ class ImportOrderChecker extends ScalariformChecker {
PositionError(offset, args.map(_.toString).toList, Some(this.errorKey + "." + errorKey))
}
+
+/**
+ * Style checker that controls what can be imported in each package. The following configuration parameters
+ * are available:
+ *
+ * - file: a reference to the import-control configuration
+ *
+ * Currently, this checker only looks at the top-level list of imports.
+ *
+ * Note that the file format is defined in a dtd (see 'import_control_1_4.dtd' in resources) and is the same
+ * as the Checkstyle ImportControl configuration. As such the same configuration can be used for Checkstyle
+ * and Scalastyle (useful for mixed projects). See
+ * https://checkstyle.sourceforge.io/config_imports.html#ImportControl for more details.
+ */
+class ImportControlChecker extends ScalariformChecker {
+
+ import ImportControlConfig._
+
+ val errorKey: String = "import.control"
+ // split in part before ., (optional) middle part with trailing dot, last part after dot
+ private val importDecomposeRe = "(.*)\\.(.*\\.)*(.*)".r
+ private var config: Try[ImportControlConfig] = _
+
+ override def setParameters(parameters: Map[String, String]): Unit = {
+ super.setParameters(parameters)
+
+ config = Try {
+ require(parameters.contains("file") || parameters.contains("inline"))
+
+ val source =
+ parameters
+ .get("file")
+ .map { fileName =>
+ val file = new java.io.File(fileName)
+ require(file.exists(), s"File '${file.getAbsolutePath}' does not exist.")
+ Source.fromFile(file)
+ }
+ .getOrElse(Source.fromString(parameters("inline")))
+ apply(source)
+ }
+ }
+
+ override def verify(ast: CompilationUnit): List[ScalastyleError] =
+ throw new UnsupportedOperationException("Wrong method called")
+
+ override def verify[T <: FileSpec](
+ file: T,
+ level: Level,
+ ast: CompilationUnit,
+ lines: Lines
+ ): List[Message[T]] =
+ verify(file.name, ast).map(p => toStyleError(file, p, level, lines))
+
+ def verify(fileName: String, ast: CompilationUnit): List[ScalastyleError] = {
+ val name = extractFileName(fileName)
+
+ val packageName = getPackageName(ast)
+
+ // Use config.get here, because we want to be able to see mistakes in configuration
+ val ruleSet = config.get.determineRulesSet(packageName, name)
+
+ val CompilationUnit(statements, _) = ast
+ statements.immediateChildren
+ .flatMap {
+ case ImportClause(_, BlockImportExpr(prefix, ImportSelectors(_, first, others, _)), _, _) =>
+ val basePackage = exprToText(prefix.contents)
+ (first :: others.map(_._2)).flatMap {
+ case s: Expr => Some((s.firstToken.offset, s"$basePackage${s.contents.head.tokens.head.text}"))
+ case _ => None
+ }
+ case n @ ImportClause(_, Expr(contents), _, _) => Some((n.firstToken.offset, exprToText(contents)))
+ case _ => Nil
+ }
+ .foldLeft((Map.empty[String, String], List.empty[PositionError])) {
+ case (
+ (scope, errors),
+ (offset, importText @ importDecomposeRe(firstPart, middlePart, importedItem))
+ ) =>
+ // scope is necessary to allow situations like:
+ // import java.util
+ // import util.Collection
+ // In that case first `java.util` is validated and then `java.util.Collection`
+ val fullImport =
+ scope
+ .get(firstPart)
+ .map(result => s"$result.${Option(middlePart).getOrElse("")}$importedItem")
+ .getOrElse(importText)
+ val newScope = if (importedItem == "_") scope else scope + (importedItem -> fullImport)
+ val result = ruleSet.validateImport(fullImport)
+ val maybeError =
+ Some(PositionError(offset, args = List(fullImport, result.rule, result.packageDescription)))
+ .filterNot(_ => result.allowed)
+ (newScope, errors ::: maybeError.toList)
+ case (result, _) => result // invalid import; ignore
+ }
+ ._2
+ }
+
+ private def extractFileName(fileName: String): String = {
+ val fileNameWithoutExtension = {
+ val lastDotPosition = fileName.lastIndexOf('.')
+ if (lastDotPosition >= 0) fileName.take(lastDotPosition)
+ else fileName
+ }
+ fileNameWithoutExtension.drop(
+ fileNameWithoutExtension.lastIndexOf(java.io.File.separatorChar.toInt) + 1
+ )
+ }
+
+ private def exprToText(contents: List[ExprElement]): String =
+ contents
+ .flatMap {
+ case GeneralTokens(toks) => toks.map(_.text)
+ case n: Any => throw new IllegalStateException(s"FIXME: unexpected expr child node $n")
+ }
+ .mkString("")
+
+ private def getPackageName(ast: CompilationUnit): String = {
+ def isPartOfPackageName(t: Token): Boolean = (t.tokenType == DOT) || (t.tokenType == VARID)
+
+ @annotation.tailrec
+ def getNextPackageName(tokens: List[Token]): (List[Token], List[Token]) = tokens match {
+ case Nil => (Nil, Nil)
+ case hd :: tail if hd.tokenType == PACKAGE => tail.span(isPartOfPackageName)
+ case l: Any => getNextPackageName(l.dropWhile(tok => tok.tokenType != PACKAGE))
+ }
+
+ @annotation.tailrec
+ def getPackageNameLoop(tokens: List[Token], myAccumulator: List[List[Token]]): List[List[Token]] =
+ getNextPackageName(tokens) match {
+ case (Nil, Nil) => myAccumulator.reverse // Return the result, but reverse since we gathered backward
+ case (Nil, remainder) =>
+ getPackageNameLoop(remainder, myAccumulator) // Found package object - try again
+ case (l, remainder) => // add match to results, go look again
+ val pkgName = l.filter(tok => tok.tokenType != DOT) // Strip out the dots between varids
+ getPackageNameLoop(remainder, pkgName :: myAccumulator)
+ }
+
+ val packageNames = getPackageNameLoop(ast.tokens, Nil)
+ packageNames.flatten.map(_.text).mkString(".")
+ }
+}
+
+object ImportControlConfig {
+ import ImportRule._
+ import SourceIdentifier._
+ import StrategyOnMismatch._
+
+ def apply(source: InputSource): ImportControlConfig = {
+ val xml = XML.load(source)
+ readImportControlConfig(xml) match {
+ case Right(config) => config
+ case Left(msg) => throw new RuntimeException(s"Configuration problem found: $msg")
+ }
+ }
+
+ private def readImportControlConfig(elem: Elem): Either[String, ImportControlConfig] = {
+ for {
+ pkg <- elem.attribute("pkg").map(_.text).toRight("'pkg' attribute is required on 'import-control' node")
+ _ <- validAttributes(elem, "pkg", "strategyOnMismatch", "regex")
+ strategyOnMismatch <-
+ readStrategyOnMismatch(elem, "disallowed")
+ .filterOrElse(
+ _ != DelegateToParent,
+ "'DelegateToParent' not allowed as 'strategyOnMismatch' for 'import-control' node"
+ )
+ regex <- Right[String, Boolean](readBoolean(elem, "regex"))
+ rules <- readImportRules(elem)
+ sources <- readSources(elem)
+ } yield ImportControlConfig(pkg, strategyOnMismatch, regex, rules, sources)
+ }
+
+ private def validAttributes(node: Node, attributes: String*): Either[String, Unit] = {
+ node.attributes.collect {
+ case attr: Attribute if !attributes.contains(attr.key) => attr.key
+ } match {
+ case Nil => Right(())
+ case head :: Nil => Left(s"Attribute '$head' is not valid on node '${node.label}'")
+ case list =>
+ Left(s"Attributes ${list.map(a => s"'$a'").mkString(", ")} are not valid on node '${node.label}'")
+ }
+ }
+
+ private def readImportRules(node: Node): Either[String, List[ImportRule]] = {
+ def readImportRule(node: Node): Either[String, ImportRule] = {
+ val isPackageRule = node.attribute("pkg").isDefined
+ for {
+ name <- node
+ .attribute("pkg")
+ .orElse(node.attribute("class"))
+ .map(_.text)
+ .toRight(s"'pkg' or 'class' attribute required on '${node.label}' node")
+ _ <-
+ if (isPackageRule) validAttributes(node, "pkg", "exact-match", "local-only", "regex")
+ else validAttributes(node, "class", "local-only", "regex")
+ } yield {
+ val exactMatch = readBoolean(node, "exact-match")
+ val localOnly = readBoolean(node, "local-only")
+ val regex = readBoolean(node, "regex")
+ (node.label, isPackageRule) match {
+ case ("allow", true) => AllowPackage(name, exactMatch, localOnly, regex)
+ case ("allow", false) => AllowClass(name, localOnly, regex)
+ case ("disallow", true) => DisallowPackage(name, exactMatch, localOnly, regex)
+ case ("disallow", false) => DisallowClass(name, localOnly, regex)
+ case _ => throw new RuntimeException("Impossible option")
+ }
+ }
+ }
+
+ node.child
+ .filter(c => c.label == "allow" || c.label == "disallow")
+ .map(readImportRule)
+ .toSeq
+ .sequence
+ }
+
+ private def readSources(node: Node): Either[String, List[SourceIdentifier]] = {
+ def readSource(node: Node): Either[String, SourceIdentifier] =
+ for {
+ name <- node
+ .attribute("name")
+ .map(_.text)
+ .toRight(s"'name' attribute required on '${node.label}' node")
+ _ <-
+ if (node.label == "file") validAttributes(node, "name", "regex")
+ else validAttributes(node, "name", "strategyOnMismatch", "regex")
+ strategyOnMismatch <- readStrategyOnMismatch(node, "delegateToParent")
+ rules <- readImportRules(node)
+ sources <- readSources(node)
+ } yield {
+ val regex = readBoolean(node, "regex")
+ node.label match {
+ case "subpackage" => Subpackage(name, strategyOnMismatch, regex, rules, sources)
+ case "file" => File(name, regex, rules)
+ }
+ }
+
+ // first process file sources to have file matches precede subpackage sources
+ ((node \ "file" map readSource) ++ (node \ "subpackage" map readSource)).sequence
+ }
+
+ private def readBoolean(node: Node, label: String): Boolean =
+ node
+ .attribute(label)
+ .map(_.text)
+ .collect {
+ case "true" => true
+ case "false" => false
+ }
+ .getOrElse(true)
+
+ implicit class SeqOfEitherOp[A, B](eithers: Seq[Either[A, B]]) {
+ def sequence: Either[A, List[B]] =
+ eithers.foldRight[Either[A, List[B]]](Right(List.empty[B])) {
+ case (Right(b), acc) => acc.map(b :: _)
+ case (Left(a), _) => Left(a)
+ }
+ }
+
+ implicit class EitherOps[A, B](either: Either[A, B]) {
+ def map[C](fn: B => C): Either[A, C] = either.right.map(fn)
+ def flatMap[C >: A, D](fn: B => Either[C, D]): Either[C, D] = either.right.flatMap(fn)
+ def filterOrElse(predicate: B => Boolean, or: => A): Either[A, B] = either match {
+ case Right(b) if predicate(b) => Right(b)
+ case Right(_) => Left(or)
+ case Left(a) => Left(a)
+ }
+ }
+
+ private def readStrategyOnMismatch(node: Node, default: String): Either[String, StrategyOnMismatch] =
+ node.attribute("strategyOnMismatch").map(_.text).getOrElse(default) match {
+ case "delegateToParent" => Right(DelegateToParent)
+ case "allowed" => Right(Allowed)
+ case "disallowed" => Right(Disallowed)
+ case other: String =>
+ Left(s"Unsupported value '$other' for 'strategyOnMismatch' attribute on '${node.label}' node")
+ }
+
+ sealed trait RuleContainer {
+ def rules: List[ImportRule]
+
+ def strategyOnMismatch: StrategyOnMismatch
+
+ override def toString: String = this match {
+ case s: ImportControlConfig => s"[root] '${s.pkg}'"
+ case s: Subpackage => s"[sub] '${s.name}'"
+ case s: SourceIdentifier.File => s"[file] '${s.name}'"
+ }
+ }
+
+ case class MatchResult(ruleContainer: RuleContainer, exactMatch: Boolean)
+
+ case class ImportValidationResult(rule: String, packageDescription: String, allowed: Boolean)
+
+ case class RuleSet(results: List[MatchResult]) {
+ def validateImport(forImport: String): ImportValidationResult = {
+ case class ImportValidationResultInternal(rule: ImportRule, allowed: Boolean)
+ def checkAccess(result: MatchResult): Option[ImportValidationResultInternal] = {
+ result.ruleContainer.rules
+ .filter(rule => result.exactMatch || !rule.localOnly)
+ .flatMap(rule => rule.verifyImport(forImport).map(ImportValidationResultInternal(rule, _)))
+ .headOption
+ }
+
+ def describeLocation(remainingResults: List[MatchResult]): String =
+ remainingResults.reverse.map(_.ruleContainer.toString).mkString("/")
+
+ @tailrec
+ def loop(remainingResults: List[MatchResult]): ImportValidationResult =
+ remainingResults match {
+ case Nil => ImportValidationResult("no match with root package", "n/a", allowed = false)
+ case head :: tail =>
+ checkAccess(head) match {
+ case Some(ImportValidationResultInternal(rule, allowed)) =>
+ ImportValidationResult(rule.toString, describeLocation(remainingResults), allowed)
+ case None =>
+ head.ruleContainer.strategyOnMismatch match {
+ case DelegateToParent => loop(tail)
+ case Disallowed =>
+ ImportValidationResult(
+ "mismatch strategy",
+ describeLocation(remainingResults),
+ allowed = false
+ )
+ case Allowed =>
+ ImportValidationResult(
+ "mismatch strategy",
+ describeLocation(remainingResults),
+ allowed = true
+ )
+ }
+ }
+ }
+
+ loop(results)
+ }
+
+ }
+
+ sealed abstract class PackageMatcher(
+ name: String,
+ regex: Boolean,
+ private[scalariform] val childSources: List[SourceIdentifier]
+ ) extends RuleContainer {
+ private val nameRegex = {
+ val regexBase = if (regex) name else Regex.quote(name)
+ (s"(?:$regexBase)(?:\\.(.*))?").r
+ }
+
+ def locateRules(packageName: String, fileName: String): Option[List[MatchResult]] =
+ packageName match {
+ case nameRegex(rest) =>
+ val subpackage = Option(rest).getOrElse("")
+ val thisResult = MatchResult(this, rest == null)
+ Some(
+ childSources
+ .map(_.locateRules(subpackage, fileName))
+ .collectFirst { case Some(list) => thisResult :: list }
+ .getOrElse(List(thisResult))
+ )
+ case _ => None
+ }
+ }
+
+ object SourceIdentifier {
+ sealed trait SourceIdentifier extends RuleContainer {
+ def locateRules(packageName: String, fileName: String): Option[List[MatchResult]]
+ }
+
+ case class Subpackage(
+ name: String,
+ strategyOnMismatch: StrategyOnMismatch = DelegateToParent,
+ regex: Boolean = true,
+ rules: List[ImportRule],
+ sourceIdentifiers: List[SourceIdentifier]
+ ) extends PackageMatcher(name, regex, sourceIdentifiers)
+ with SourceIdentifier
+
+ case class File(name: String, regex: Boolean = true, rules: List[ImportRule]) extends SourceIdentifier {
+ private val nameRegex = {
+ val regexBase = if (regex) name else Regex.quote(name)
+ (s"(?:$regexBase)").r
+ }
+
+ override val strategyOnMismatch: StrategyOnMismatch = DelegateToParent
+
+ override def locateRules(packageName: String, fileName: String): Option[List[MatchResult]] =
+ Option(fileName)
+ .filter(_ => packageName.isEmpty)
+ .filter {
+ case nameRegex() => true
+ case _ => false
+ }
+ .map(_ => List(MatchResult(this, exactMatch = true)))
+ }
+ }
+
+ object ImportRule {
+ sealed trait ImportRule {
+ val name: String
+
+ val localOnly: Boolean
+
+ val exactMatch: Boolean = true
+
+ val regex: Boolean
+
+ def verifyImport(forImport: String): Option[Boolean]
+
+ override def toString: String = this match {
+ case rule: AllowPackage => s"allow rule for package '${rule.name}'"
+ case rule: DisallowPackage => s"disallow rule for package '${rule.name}'"
+ case rule: AllowClass => s"allow rule for class '${rule.name}'"
+ case rule: DisallowClass => s"disallow rule for class '${rule.name}'"
+ }
+ }
+
+ sealed protected abstract class PackageImportRule(allow: Boolean) extends ImportRule {
+ private val nameRegex = {
+ val regexBase = if (regex) name else Regex.quote(name)
+ (s"(?:$regexBase)(?:\\..*)?(\\..*)?").r
+ }
+
+ private def importMatch(forImport: String): Boolean =
+ forImport match {
+ case nameRegex(lastSegment) => !exactMatch || lastSegment == null
+ case _ => false
+ }
+
+ override def verifyImport(forImport: String): Option[Boolean] =
+ if (importMatch(forImport)) Some(allow) else None
+ }
+
+ sealed protected abstract class ClassImportRule(allow: Boolean) extends ImportRule {
+ private val nameRegex = {
+ val regexBase = if (regex) name else Regex.quote(name)
+ (s"(?:$regexBase)").r
+ }
+
+ private def importMatch(forImport: String): Boolean =
+ forImport match {
+ case nameRegex() => true
+ case _ => false
+ }
+
+ override def verifyImport(forImport: String): Option[Boolean] =
+ if (importMatch(forImport)) Some(allow) else None
+ }
+
+ case class AllowPackage(
+ name: String,
+ override val exactMatch: Boolean = true,
+ localOnly: Boolean = true,
+ regex: Boolean = true
+ ) extends PackageImportRule(allow = true)
+
+ case class AllowClass(name: String, localOnly: Boolean = true, regex: Boolean = true)
+ extends ClassImportRule(allow = true)
+
+ case class DisallowPackage(
+ name: String,
+ override val exactMatch: Boolean = true,
+ localOnly: Boolean = true,
+ regex: Boolean = true
+ ) extends PackageImportRule(allow = false)
+
+ case class DisallowClass(name: String, localOnly: Boolean = true, regex: Boolean = true)
+ extends ClassImportRule(allow = false)
+ }
+
+ object StrategyOnMismatch {
+ sealed trait StrategyOnMismatch
+
+ case object DelegateToParent extends StrategyOnMismatch
+
+ case object Allowed extends StrategyOnMismatch
+
+ case object Disallowed extends StrategyOnMismatch
+ }
+
+ case class ImportControlConfig(
+ pkg: String,
+ strategyOnMismatch: StrategyOnMismatch = Disallowed,
+ regex: Boolean = true,
+ rules: List[ImportRule],
+ sourceIdentifiers: List[SourceIdentifier]
+ ) extends PackageMatcher(pkg, regex, sourceIdentifiers) {
+
+ def determineRulesSet(packageName: String, fileName: String): RuleSet =
+ RuleSet(locateRules(packageName, fileName).getOrElse(Nil).reverse)
+ }
+}
diff --git a/src/test/scala/org/scalastyle/file/CheckerTest.scala b/src/test/scala/org/scalastyle/file/CheckerTest.scala
index 49957782..a607dffd 100644
--- a/src/test/scala/org/scalastyle/file/CheckerTest.scala
+++ b/src/test/scala/org/scalastyle/file/CheckerTest.scala
@@ -32,9 +32,7 @@ trait CheckerTest {
protected val key: String
protected val classUnderTest: Class[_ <: Checker[_]]
- object NullFileSpec extends FileSpec {
- def name(): String = ""
- }
+ case class CustomFileSpec(name: String) extends FileSpec
protected def assertErrors[T <: FileSpec](
expected: List[Message[T]],
@@ -42,7 +40,8 @@ trait CheckerTest {
params: Map[String, String] = Map(),
customMessage: Option[String] = None,
commentFilter: Boolean = true,
- customId: Option[String] = None
+ customId: Option[String] = None,
+ customFileName: String = ""
) = {
val classes = List(
ConfigurationChecker(classUnderTest.getName(), WarningLevel, true, params, customMessage, customId)
@@ -50,22 +49,38 @@ trait CheckerTest {
val configuration = ScalastyleConfiguration("", commentFilter, classes)
assertEquals(
expected.mkString("\n"),
- new CheckerUtils().verifySource(configuration, classes, NullFileSpec, source).mkString("\n")
+ new CheckerUtils()
+ .verifySource(configuration, classes, CustomFileSpec(customFileName), source)
+ .mkString("\n")
)
}
- protected def fileError(args: List[String] = List(), customMessage: Option[String] = None) =
- StyleError(NullFileSpec, classUnderTest, key, WarningLevel, args, None, None, customMessage)
- protected def lineError(line: Int, args: List[String] = List()) =
- StyleError(NullFileSpec, classUnderTest, key, WarningLevel, args, Some(line), None)
+ protected def fileError(
+ args: List[String] = List(),
+ customMessage: Option[String] = None,
+ customFileName: String = ""
+ ) =
+ StyleError(
+ CustomFileSpec(customFileName),
+ classUnderTest,
+ key,
+ WarningLevel,
+ args,
+ None,
+ None,
+ customMessage
+ )
+ protected def lineError(line: Int, args: List[String] = List(), customFileName: String = "") =
+ StyleError(CustomFileSpec(customFileName), classUnderTest, key, WarningLevel, args, Some(line), None)
protected def columnError(
line: Int,
column: Int,
args: List[String] = List(),
- errorKey: Option[String] = None
+ errorKey: Option[String] = None,
+ customFileName: String = ""
) =
StyleError(
- NullFileSpec,
+ CustomFileSpec(customFileName),
classUnderTest,
errorKey.getOrElse(key),
WarningLevel,
diff --git a/src/test/scala/org/scalastyle/scalariform/ImportsCheckerTest.scala b/src/test/scala/org/scalastyle/scalariform/ImportsCheckerTest.scala
index 7413f27b..bf8a39b6 100644
--- a/src/test/scala/org/scalastyle/scalariform/ImportsCheckerTest.scala
+++ b/src/test/scala/org/scalastyle/scalariform/ImportsCheckerTest.scala
@@ -20,6 +20,10 @@ import org.junit.Assert.assertTrue
import org.junit.Test
import org.scalastyle.file.CheckerTest
import org.scalatestplus.junit.AssertionsForJUnit
+import org.scalastyle.scalariform.ImportControlConfig.{MatchResult, PackageMatcher, RuleSet, SourceIdentifier}
+import org.scalastyle.scalariform.ImportControlConfig.SourceIdentifier.{SourceIdentifier, Subpackage}
+
+import scala.xml.Source
// scalastyle:off magic.number multiple.string.literals
@@ -478,3 +482,179 @@ class ImportOrderCheckerTest extends AssertionsForJUnit with CheckerTest {
private def errorKey(subkey: String): Option[String] = Some(key + "." + subkey)
}
+
+class ImportControlCheckerTest extends AssertionsForJUnit with CheckerTest {
+
+ import org.scalatest.matchers.should.Matchers._
+
+ val key = "import.control"
+ val classUnderTest = classOf[ImportControlChecker]
+
+ @Test def testRuleSetMatching(): Unit = {
+ import TestTools.MatcherResultOps
+ val config = ImportControlConfig(Source.fromString("""
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |""".stripMargin))
+
+ val rootPartialMatch = MatchResult(config, exactMatch = false)
+ val rootExactMatch = MatchResult(config, exactMatch = true)
+ val exact = true
+ config.determineRulesSet("com.mismatch", "MyFile") shouldBe RuleSet(Nil)
+ config.determineRulesSet("com.test", "MyFile") shouldBe RuleSet(
+ List(rootPartialMatch.file("MyFile"), rootExactMatch)
+ )
+ config.determineRulesSet("com.test", "OtherFile") shouldBe RuleSet(List(rootExactMatch))
+ config.determineRulesSet("com.test.submismatch", "MyFile") shouldBe RuleSet(List(rootPartialMatch))
+ config.determineRulesSet("com.test.subpackage", "SomeFile") shouldBe RuleSet(
+ List(rootPartialMatch.subpackage("subpack.*", exact), rootPartialMatch)
+ )
+ config.determineRulesSet("com.test.subpackage", "MyFile") shouldBe RuleSet(
+ List(
+ rootPartialMatch.subpackage("subpack.*").file("MyFile"),
+ rootPartialMatch.subpackage("subpack.*", exact),
+ rootPartialMatch
+ )
+ )
+ config.determineRulesSet("com.test.subpackage.subsub", "SomeFile") shouldBe RuleSet(
+ List(rootPartialMatch.subpackage("subpack.*", exact), rootPartialMatch)
+ )
+ config.determineRulesSet("com.testpackage", "MyFile") shouldBe RuleSet(Nil)
+ config.determineRulesSet("com.test.sub", "SomeFile") shouldBe RuleSet(
+ List(rootPartialMatch.subpackage("sub", exact), rootPartialMatch)
+ )
+ config.determineRulesSet("com.test.sub", "MyFile") shouldBe RuleSet(
+ List(
+ rootPartialMatch.subpackage("sub").file("MyFile"),
+ rootPartialMatch.subpackage("sub", exact),
+ rootPartialMatch
+ )
+ )
+ }
+
+ @Test def testImportMatching(): Unit = {
+ import TestTools.RuleSetOp
+ val config = ImportControlConfig(
+ Source.fromString("""
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |""".stripMargin)
+ )
+
+ config.determineRulesSet("com.mismatch", "MyFile").isAllowed("com.allowed.SomeClass") shouldBe false
+ config.determineRulesSet("com.test", "MyFile").isAllowed("com.allowed.SomeClass") shouldBe true
+ config.determineRulesSet("com.test", "MyFile").isAllowed("com.disallowed.SomeClass") shouldBe false
+ config.determineRulesSet("com.test", "MyFile").isAllowed("com.other.SomeClass") shouldBe true
+ config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.ext") shouldBe true
+ config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.ext.AClass") shouldBe true
+ config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.ext.sub.AClass") shouldBe true
+ config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.disallowed.AClass") shouldBe false
+ config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.disallowed.SomeClass") shouldBe true
+ config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.disallowed.OtherClass") shouldBe false
+ config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.allowed.SomeClass") shouldBe true
+ config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.whatever.SomeClass") shouldBe false
+ config.determineRulesSet("com.test.sub", "MyFile").isAllowed("com.other.allowed.SomeClass") shouldBe false
+ config.determineRulesSet("com.test", "MyFile").isAllowed("com.other.allowed.SomeClass") shouldBe true
+ config.determineRulesSet("com.test", "SomeFile").isAllowed("com.other.allowed.SomeClass") shouldBe true
+ config.determineRulesSet("com.test.subpack", "MyFile").isAllowed("com.whatever.SomeClass") shouldBe true
+ config
+ .determineRulesSet("com.test.subpack.sub", "MyFile")
+ .isAllowed("com.whatever.SomeClass") shouldBe true
+ config.determineRulesSet("com.test.testasub", "MyFile").isAllowed("com.allowed.SomeClass") shouldBe true
+ config.determineRulesSet("com.test.test.sub", "MyFile").isAllowed("com.allowed.SomeClass") shouldBe false
+ }
+
+ @Test def testWithSource(): Unit = {
+ val source =
+ """package com
+ |package test
+ |
+ |import com.allowed.MyClass
+ |import com.allowed.{Class1, Class2 => Renamed}
+ |import com.allowed.sub._
+ |import com.disallowed.SomeClass
+ |import com.disallowed.{Class1, Class2 => Renamed}
+ |import com.disallowed.sub._
+ |import com.disallowed.SomeClass._
+ |import com.allowed
+ |import allowed.MyOtherClass
+ |import com.disallowed
+ |import disallowed.MyOtherClass
+ |""".stripMargin
+
+ val config =
+ """
+ |
+ |
+ |
+ |""".stripMargin
+
+ val expected = List(
+ columnError(7, 0, args = List("com.disallowed.SomeClass", "mismatch strategy", "[root] 'com.test'")),
+ columnError(8, 23, args = List("com.disallowed.Class1", "mismatch strategy", "[root] 'com.test'")),
+ columnError(8, 31, args = List("com.disallowed.Class2", "mismatch strategy", "[root] 'com.test'")),
+ columnError(9, 0, args = List("com.disallowed.sub._", "mismatch strategy", "[root] 'com.test'")),
+ columnError(10, 0, args = List("com.disallowed.SomeClass._", "mismatch strategy", "[root] 'com.test'")),
+ columnError(13, 0, args = List("com.disallowed", "mismatch strategy", "[root] 'com.test'")),
+ columnError(14, 0, args = List("com.disallowed.MyOtherClass", "mismatch strategy", "[root] 'com.test'"))
+ )
+
+ val sep = java.io.File.separator
+ assertErrors(
+ expected,
+ source,
+ params = Map("inline" -> config),
+ customFileName = s"com${sep}test${sep}Test.scala"
+ )
+ }
+
+ object TestTools {
+ implicit class MatcherResultOps(matchResult: MatchResult) {
+ private def childSources: List[SourceIdentifier] =
+ matchResult.ruleContainer match {
+ case packageMatcher: PackageMatcher => packageMatcher.childSources
+ case _ => Nil
+ }
+
+ def subpackage(name: String, exactMatch: Boolean = false): MatchResult =
+ childSources
+ .collectFirst {
+ case s: Subpackage if s.name == name => MatchResult(s, exactMatch)
+ }
+ .getOrElse(fail(s"Can't find subpackage with name '$name'"))
+
+ def file(name: String): MatchResult =
+ childSources
+ .collectFirst {
+ case f: SourceIdentifier.File if f.name == name => MatchResult(f, exactMatch = true)
+ }
+ .getOrElse(fail(s"Can't find file with name '$name'"))
+ }
+
+ implicit class RuleSetOp(ruleSet: RuleSet) {
+ def isAllowed(forImport: String): Boolean = ruleSet.validateImport(forImport).allowed
+ }
+ }
+}