diff --git a/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala b/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala index 902d08174..1e8872500 100644 --- a/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala +++ b/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala @@ -117,6 +117,14 @@ object JavaMain { c.copy(runtime = c.runtime.copy(pythonPackage = x)) } text("Python package (Python only, default: root package)") + opt[Unit]("python-type-annotations") action { (x, c) => + c.copy(runtime = c.runtime.copy(pythonTypeAnnotations = true)) + } text("generate Python type annotations (Python only, default: true)") + + opt[Unit]("no-python-type-annotations") action { (x, c) => + c.copy(runtime = c.runtime.copy(pythonTypeAnnotations = false)) + } text("disable Python type annotations (Python only)") + opt[String]("nim-module") valueName("") action { (x, c) => c.copy(runtime = c.runtime.copy(nimModule = x)) } text("Path of Nim runtime module (Nim only, default: kaitai_struct_nim_runtime)") diff --git a/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala b/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala index ef613946f..34a4bf04c 100644 --- a/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala +++ b/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala @@ -90,6 +90,7 @@ case class JavaRuntimeConfig( * @param dotNetNamespace .NET (C#) namespace * @param phpNamespace PHP namespace * @param pythonPackage Python package name + * @param pythonTypeAnnotations If true, generate Python type annotations * @param nimModule Path of Nim runtime module * @param nimOpaque Directory of opaque Nim modules */ @@ -104,6 +105,7 @@ case class RuntimeConfig( dotNetNamespace: String = "Kaitai", phpNamespace: String = "", pythonPackage: String = "", + pythonTypeAnnotations: Boolean = true, nimModule: String = "kaitai_struct_nim_runtime", nimOpaque: String = "" ) diff --git a/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala index ad2373fb7..72f7ba120 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala @@ -25,6 +25,10 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) import PythonCompiler._ override val translator = new PythonTranslator(typeProvider, importList, config) + + // Track attribute types for type annotation generation + private val attributeTypes = scala.collection.mutable.Map[String, (DataType, Boolean)]() + private var currentClassName: String = "" override def innerDocstrings = true @@ -43,13 +47,19 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) outHeader.puts(s"# $headerComment") // https://github.com/kaitai-io/kaitai_struct/issues/675 - // TODO: Make conditional once we'll have Python type annotations - outHeader.puts("# type: ignore") + if (!config.pythonTypeAnnotations) { + outHeader.puts("# type: ignore") + } outHeader.puts importList.add("import kaitaistruct") importList.add(s"from kaitaistruct import $kstructName, $kstreamName, BytesIO") + + // Import typing module when type annotations are enabled + if (config.pythonTypeAnnotations) { + importList.add("from typing import Any, List, Optional, Union") + } out.puts out.puts @@ -81,27 +91,63 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) PythonCompiler.externalTypeDeclaration(extType, importList, config) override def classHeader(name: String): Unit = { + currentClassName = name + attributeTypes.clear() out.puts(s"class ${type2class(name)}($kstructName):") out.inc } override def classConstructorHeader(name: String, parentType: DataType, rootClassName: String, isHybrid: Boolean, params: List[ParamDefSpec]): Unit = { - val endianAdd = if (isHybrid) ", _is_le=None" else "" - val paramsList = Utils.join(params.map((p) => paramName(p.id)), ", ", ", ", "") - - out.puts(s"def __init__(self$paramsList, _io, _parent=None, _root=None$endianAdd):") + implicit val provider: ClassTypeProvider = typeProvider + + // Build parameter list with type annotations if enabled + val paramsList = if (config.pythonTypeAnnotations) { + val paramsWithTypes = params.map { p => + val paramType = PythonCompiler.kaitaiTypeToPythonType(p.dataType) + s"${paramName(p.id)}: $paramType" + } + Utils.join(paramsWithTypes, ", ", ", ", "") + } else { + Utils.join(params.map((p) => paramName(p.id)), ", ", ", ", "") + } + + // Build constructor signature with type annotations + if (config.pythonTypeAnnotations) { + val endianAdd = if (isHybrid) ", _is_le: Optional[bool] = None" else "" + val parentTypeStr = if (name == rootClassName) "'KaitaiStruct'" else "Optional['KaitaiStruct']" + out.puts(s"def __init__(self$paramsList, _io: 'KaitaiStream', _parent: $parentTypeStr = None, _root: Optional['KaitaiStruct'] = None$endianAdd) -> None:") + } else { + val endianAdd = if (isHybrid) ", _is_le=None" else "" + out.puts(s"def __init__(self$paramsList, _io, _parent=None, _root=None$endianAdd):") + } + out.inc - out.puts("self._io = _io") - out.puts("self._parent = _parent") - if (name == rootClassName) { - out.puts("self._root = _root if _root else self") + + // Add type annotations for instance variables if enabled + if (config.pythonTypeAnnotations) { + out.puts("self._io: 'KaitaiStream' = _io") + out.puts(s"self._parent: ${if (name == rootClassName) "'KaitaiStruct'" else "Optional['KaitaiStruct']"} = _parent") + if (name == rootClassName) { + out.puts("self._root: 'KaitaiStruct' = _root if _root else self") + } else { + out.puts("self._root: Optional['KaitaiStruct'] = _root") + } + + if (isHybrid) + out.puts("self._is_le: Optional[bool] = _is_le") } else { - out.puts("self._root = _root") + out.puts("self._io = _io") + out.puts("self._parent = _parent") + if (name == rootClassName) { + out.puts("self._root = _root if _root else self") + } else { + out.puts("self._root = _root") + } + + if (isHybrid) + out.puts("self._is_le = _is_le") } - if (isHybrid) - out.puts("self._is_le = _is_le") - // Store parameters passed to us params.foreach((p) => handleAssignmentSimple(p.id, paramName(p.id))) @@ -109,6 +155,18 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) importList.add("import collections") out.puts("self._debug = collections.defaultdict(dict)") } + + // Emit forward declarations for type annotations + if (config.pythonTypeAnnotations && attributeTypes.nonEmpty) { + out.puts + out.puts("# Attribute type declarations") + attributeTypes.foreach { case (attrName, (dataType, isNullable)) => + implicit val provider: ClassTypeProvider = typeProvider + val typeStr = PythonCompiler.kaitaiTypeToPythonType(dataType, isNullable) + val privateAttrName = idToStr(NamedIdentifier(attrName)) + out.puts(s"self.$privateAttrName: $typeStr") + } + } } override def runRead(name: List[String]): Unit = { @@ -135,7 +193,8 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) case Some(e) => s"_${e.toSuffix}" case None => "" } - out.puts(s"def _read$suffix(self):") + val returnType = if (config.pythonTypeAnnotations) " -> None" else "" + out.puts(s"def _read$suffix(self)$returnType:") out.inc if (isEmpty) out.puts("pass") @@ -143,7 +202,12 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def readFooter() = universalFooter - override def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = {} + override def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { + // Store attribute type information for later use in type annotations + if (config.pythonTypeAnnotations) { + attributeTypes(idToStr(attrName)) = (attrType, isNullable) + } + } override def attributeReader(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = {} @@ -452,7 +516,13 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def instanceHeader(className: String, instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit = { out.puts("@property") - out.puts(s"def ${publicMemberName(instName)}(self):") + if (config.pythonTypeAnnotations) { + implicit val provider: ClassTypeProvider = typeProvider + val typeStr = PythonCompiler.kaitaiTypeToPythonType(dataType, isNullable) + out.puts(s"def ${publicMemberName(instName)}(self) -> $typeStr:") + } else { + out.puts(s"def ${publicMemberName(instName)}(self):") + } out.inc } @@ -486,7 +556,8 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def classToString(toStringExpr: Ast.expr): Unit = { out.puts - out.puts("def __repr__(self):") + val returnType = if (config.pythonTypeAnnotations) " -> str" else "" + out.puts(s"def __repr__(self)$returnType:") out.inc out.puts(s"return ${translator.translate(toStringExpr)}") out.dec @@ -546,6 +617,57 @@ object PythonCompiler extends LanguageCompilerStatic config: RuntimeConfig ): LanguageCompiler = new PythonCompiler(tp, config) + /** + * Maps Kaitai data types to Python type annotation strings + */ + def kaitaiTypeToPythonType(dataType: DataType, isNullable: Boolean = false)(implicit classTypeProvider: ClassTypeProvider): String = { + def wrapNullable(t: String): String = if (isNullable) s"Optional[$t]" else t + + val baseType = dataType match { + // Primitive types + case _: Int1Type | _: IntMultiType | _: BitsType => "int" + case _: FloatMultiType => "float" + case _: BooleanType => "bool" + case _: BytesType => "bytes" + case _: StrType => "str" + + // Complex types + case at: ArrayType => s"List[${kaitaiTypeToPythonType(at.elType)}]" + case ut: UserType => + val typeName = types2class(ut.classSpec.get.name, ut.isExternal(classTypeProvider.nowClass)) + s"'$typeName'" // Use forward reference to handle circular dependencies + + // Enum types + case et: EnumType => + val enumName = type2class(et.enumSpec.get.name.last) + s"'$enumName'" + + // Stream types + case OwnedKaitaiStreamType => "'KaitaiStream'" + case KaitaiStreamType => "'KaitaiStream'" + case KaitaiStructType => "'KaitaiStruct'" + case _: CalcKaitaiStructType => "'KaitaiStruct'" + + // Switch types - we'll use Union for these + case st: SwitchType => + val types = st.cases.values.toSet + if (types.size == 1) { + kaitaiTypeToPythonType(types.head) + } else { + val typeStrs = types.map(kaitaiTypeToPythonType(_)).mkString(", ") + s"Union[$typeStrs]" + } + + // Any type + case AnyType => "Any" + + // Default fallback + case _ => "Any" + } + + wrapNullable(baseType) + } + def idToStr(id: Identifier): String = id match { case SpecialIdentifier(name) => name