From c856429d90d58df0a1367487da5e82bc91a3e161 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Fri, 7 Oct 2016 17:31:21 +0200 Subject: [PATCH 01/17] Introduced vector expressions --- src/impala/ast.h | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/impala/ast.h b/src/impala/ast.h index b37b4c75e..a1764edd2 100644 --- a/src/impala/ast.h +++ b/src/impala/ast.h @@ -1334,6 +1334,31 @@ class MapExpr : public Expr, public Args, public TypeArgs { friend class TypeSema; }; +class VectorExpr : public Expr { +public: + enum Type { + Vec2, Vec3, Vec4, + Mat2, Mat2x3, Mat2x4, + Mat3, Mat3x2, Mat3x4, + Mat4, Mat4x2, Mat4x3, + Inverse, Dot, Cross, Normalize, Length + }; + + Type type() const { return type_; } + virtual void check(NameSema&) const override; + virtual void check(BorrowSema&) const override; + +private: + virtual std::ostream& stream(std::ostream&) const override; + virtual Type check(TypeSema&, TypeExpectation) const override; + virtual thorin::Value lemit(CodeGen&) const override; + virtual const thorin::Def* remit(CodeGen&) const override; + + friend class CodeGen; + friend class Parser; + friend class TypeSema; +}; + class StmtLikeExpr : public Expr {}; class BlockExprBase : public StmtLikeExpr { From 2069a86341c825ca69e73e4ed76ea23ed3e1b28c Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Fri, 7 Oct 2016 17:40:25 +0200 Subject: [PATCH 02/17] Added matrix/vector type --- src/impala/sema/unifiable.h | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/impala/sema/unifiable.h b/src/impala/sema/unifiable.h index 76a8d7ee5..ea9085585 100644 --- a/src/impala/sema/unifiable.h +++ b/src/impala/sema/unifiable.h @@ -162,6 +162,7 @@ enum Kind { Kind_noret, Kind_owned_ptr, Kind_simd, + Kind_matrix, Kind_struct_abs, Kind_struct_app, Kind_trait_abs, @@ -661,6 +662,29 @@ class SimdTypeNode : public ArrayTypeNode { const uint64_t size_; }; +class MatrixTypeNode : public KnownTypeNode { +public: + MatrixTypeNode(TypeTable& tt, PrimTypeKind kind, uint32_t rows, uint32_t cols) + : KnownTypeNode(tt, Kind_matrix, {}), kind_(kind), rows_(rows), cols_(cols) + {} + + uint32_t rows() const { return rows_; } + uint32_t cols() const { return cols_; } + PrimTypeKind kind() const { return kind_; } + + virtual std::ostream& stream(std::ostream&) const override; + virtual bool is_subtype(const TypeNode*) const override; + virtual bool equal(const Unifiable*) const override; + +private: + virtual Type vinstantiate(SpecializeMap&) const override; + virtual const thorin::Type* convert(CodeGen&) const override; + + PrimTypeKind kind_; + uint32_t rows_; + uint32_t cols_; +}; + /** * Represents a declared trait. * A trait consists of a name, a number of declared methods and a number of From 57a166ac09c02736cfa1bc15dfffa632b726d576 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Fri, 7 Oct 2016 17:45:24 +0200 Subject: [PATCH 03/17] Added tokens --- src/impala/sema/typetable.h | 3 +++ src/impala/tokenlist.h | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/impala/sema/typetable.h b/src/impala/sema/typetable.h index 1c4bb2efa..6606b7874 100644 --- a/src/impala/sema/typetable.h +++ b/src/impala/sema/typetable.h @@ -32,6 +32,9 @@ class TypeTable { SimdType simd_type(Type elem_type, uint64_t size) { return join(new SimdTypeNode(*this, elem_type, size)); } + MatrixType matrix_type(PrimTypeKind kind, uint32_t rows, uint32_t cols) { + return join(new MatrixTypeNode(*this, kind, rows, cols)); + } MutPtrType mut_ptr_type(Type referenced_type, int addr_space = 0) { return join(new MutPtrTypeNode(*this, referenced_type, addr_space)); } diff --git a/src/impala/tokenlist.h b/src/impala/tokenlist.h index f40e4f7ba..d6ad5d237 100644 --- a/src/impala/tokenlist.h +++ b/src/impala/tokenlist.h @@ -104,6 +104,19 @@ IMPALA_KEY(TYPEOF, "typeof") IMPALA_KEY(WHILE, "while") IMPALA_KEY(SIMD, "simd") +IMPALA_KEY(VEC2, "vec2") +IMPALA_KEY(VEC3, "vec3") +IMPALA_KEY(VEC4, "vec4") +IMPALA_KEY(MAT2, "mat2") +IMPALA_KEY(MAT3, "mat3") +IMPALA_KEY(MAT4, "mat4") +IMPALA_KEY(MAT2X3, "mat2x3") +IMPALA_KEY(MAT2X4, "mat2x4") +IMPALA_KEY(MAT3X2, "mat3x2") +IMPALA_KEY(MAT3X4, "mat3x4") +IMPALA_KEY(MAT4X2, "mat4x2") +IMPALA_KEY(MAT4X3, "mat4x3") + #undef IMPALA_KEY #ifndef IMPALA_MISC From 17fa3c10011c65b64babfe69bd091a4b29ffb128 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Tue, 11 Oct 2016 11:33:59 +0200 Subject: [PATCH 04/17] Compilation fixes --- src/impala/ast.h | 17 ++++++++--------- src/impala/parser.cpp | 17 +++++++++++++++++ src/impala/sema/typetable.h | 4 ++-- src/impala/sema/unifiable.h | 8 ++++---- src/impala/tokenlist.h | 35 +++++++++++++++++++++++------------ 5 files changed, 54 insertions(+), 27 deletions(-) diff --git a/src/impala/ast.h b/src/impala/ast.h index a1764edd2..6dc9fbe94 100644 --- a/src/impala/ast.h +++ b/src/impala/ast.h @@ -1334,17 +1334,14 @@ class MapExpr : public Expr, public Args, public TypeArgs { friend class TypeSema; }; -class VectorExpr : public Expr { -public: - enum Type { - Vec2, Vec3, Vec4, - Mat2, Mat2x3, Mat2x4, - Mat3, Mat3x2, Mat3x4, - Mat4, Mat4x2, Mat4x3, - Inverse, Dot, Cross, Normalize, Length +class VectorExpr : public Expr, public Args { +public: + enum Kind { +#define IMPALA_VEC_KEY(tok, str) tok = Token:: tok, +#include "tokenlist.h" }; - Type type() const { return type_; } + Kind kind() const { return kind_; } virtual void check(NameSema&) const override; virtual void check(BorrowSema&) const override; @@ -1357,6 +1354,8 @@ class VectorExpr : public Expr { friend class CodeGen; friend class Parser; friend class TypeSema; + + Kind kind_; }; class StmtLikeExpr : public Expr {}; diff --git a/src/impala/parser.cpp b/src/impala/parser.cpp index c048a79a1..a97eb91f7 100644 --- a/src/impala/parser.cpp +++ b/src/impala/parser.cpp @@ -229,6 +229,7 @@ class Parser { const Expr* parse_infix_expr(const Expr* lhs); const Expr* parse_postfix_expr(const Expr* lhs); const Expr* parse_primary_expr(); + const VectorExpr* parse_vector_expr(); const LiteralExpr* parse_literal_expr(); const CharExpr* parse_char_expr(); const StrExpr* parse_str_expr(); @@ -1008,6 +1009,10 @@ const Expr* Parser::parse_primary_expr() { parse_comma_list("elements of a simd expression", Token::R_BRACKET, [&] { simd->args_.push_back(parse_expr()); }); return simd; } +#define IMPALA_VEC_KEY(tok, str) case Token:: tok: +#include "impala/tokenlist.h" + return parse_vector_expr(); + #define IMPALA_LIT(itype, atype) \ case Token::LIT_##itype: #include "impala/tokenlist.h" @@ -1066,6 +1071,18 @@ const Expr* Parser::parse_primary_expr() { } } +const VectorExpr* Parser::parse_vector_expr() { + auto vec = loc(new VectorExpr()); + switch (la()) { +#define IMPALA_VEC_KEY(tok, str) case Token:: tok: vec->kind_ = VectorExpr:: tok; break; +#include "tokenlist.h" + default: THORIN_UNREACHABLE; + } + expect(Token::L_PAREN, "vector expression"); + parse_comma_list("elements of a vector expression", Token::R_PAREN, [&] { vec->args_.push_back(parse_expr()); }); + return vec; +} + const LiteralExpr* Parser::parse_literal_expr() { LiteralExpr::Kind kind; Box box; diff --git a/src/impala/sema/typetable.h b/src/impala/sema/typetable.h index 6606b7874..58128dff2 100644 --- a/src/impala/sema/typetable.h +++ b/src/impala/sema/typetable.h @@ -32,8 +32,8 @@ class TypeTable { SimdType simd_type(Type elem_type, uint64_t size) { return join(new SimdTypeNode(*this, elem_type, size)); } - MatrixType matrix_type(PrimTypeKind kind, uint32_t rows, uint32_t cols) { - return join(new MatrixTypeNode(*this, kind, rows, cols)); + MatrixType matrix_type(Type elem_type, uint32_t rows, uint32_t cols) { + return join(new MatrixTypeNode(*this, elem_type, rows, cols)); } MutPtrType mut_ptr_type(Type referenced_type, int addr_space = 0) { return join(new MutPtrTypeNode(*this, referenced_type, addr_space)); diff --git a/src/impala/sema/unifiable.h b/src/impala/sema/unifiable.h index ea9085585..ce77ed04b 100644 --- a/src/impala/sema/unifiable.h +++ b/src/impala/sema/unifiable.h @@ -42,6 +42,7 @@ class OwnedPtrTypeNode; typedef Proxy OwnedPtr class PrimTypeNode; typedef Proxy PrimType; class PtrTypeNode; typedef Proxy PtrType; class SimdTypeNode; typedef Proxy SimdType; +class MatrixTypeNode; typedef Proxy MatrixType; class StructAbsTypeNode; typedef Proxy StructAbsType; class StructAppTypeNode; typedef Proxy StructAppType; class TraitAbsNode; typedef Proxy TraitAbs; @@ -664,13 +665,13 @@ class SimdTypeNode : public ArrayTypeNode { class MatrixTypeNode : public KnownTypeNode { public: - MatrixTypeNode(TypeTable& tt, PrimTypeKind kind, uint32_t rows, uint32_t cols) - : KnownTypeNode(tt, Kind_matrix, {}), kind_(kind), rows_(rows), cols_(cols) + MatrixTypeNode(TypeTable& tt, Type elem_type, uint32_t rows, uint32_t cols) + : KnownTypeNode(tt, Kind_matrix, {elem_type}), rows_(rows), cols_(cols) {} + Type elem_kind() const { return arg(0); } uint32_t rows() const { return rows_; } uint32_t cols() const { return cols_; } - PrimTypeKind kind() const { return kind_; } virtual std::ostream& stream(std::ostream&) const override; virtual bool is_subtype(const TypeNode*) const override; @@ -680,7 +681,6 @@ class MatrixTypeNode : public KnownTypeNode { virtual Type vinstantiate(SpecializeMap&) const override; virtual const thorin::Type* convert(CodeGen&) const override; - PrimTypeKind kind_; uint32_t rows_; uint32_t cols_; }; diff --git a/src/impala/tokenlist.h b/src/impala/tokenlist.h index d6ad5d237..da0177129 100644 --- a/src/impala/tokenlist.h +++ b/src/impala/tokenlist.h @@ -104,19 +104,30 @@ IMPALA_KEY(TYPEOF, "typeof") IMPALA_KEY(WHILE, "while") IMPALA_KEY(SIMD, "simd") -IMPALA_KEY(VEC2, "vec2") -IMPALA_KEY(VEC3, "vec3") -IMPALA_KEY(VEC4, "vec4") -IMPALA_KEY(MAT2, "mat2") -IMPALA_KEY(MAT3, "mat3") -IMPALA_KEY(MAT4, "mat4") -IMPALA_KEY(MAT2X3, "mat2x3") -IMPALA_KEY(MAT2X4, "mat2x4") -IMPALA_KEY(MAT3X2, "mat3x2") -IMPALA_KEY(MAT3X4, "mat3x4") -IMPALA_KEY(MAT4X2, "mat4x2") -IMPALA_KEY(MAT4X3, "mat4x3") +#ifndef IMPALA_VEC_KEY +#define IMPALA_VEC_KEY(tok, str) IMPALA_KEY(tok, str) +#endif +IMPALA_VEC_KEY(VEC2, "vec2") +IMPALA_VEC_KEY(VEC3, "vec3") +IMPALA_VEC_KEY(VEC4, "vec4") +IMPALA_VEC_KEY(MAT2, "mat2") +IMPALA_VEC_KEY(MAT3, "mat3") +IMPALA_VEC_KEY(MAT4, "mat4") +IMPALA_VEC_KEY(MAT2X3, "mat2x3") +IMPALA_VEC_KEY(MAT2X4, "mat2x4") +IMPALA_VEC_KEY(MAT3X2, "mat3x2") +IMPALA_VEC_KEY(MAT3X4, "mat3x4") +IMPALA_VEC_KEY(MAT4X2, "mat4x2") +IMPALA_VEC_KEY(MAT4X3, "mat4x3") + +IMPALA_VEC_KEY(MAT_INVERSE, "inverse") +IMPALA_VEC_KEY(VEC_DOT, "dot") +IMPALA_VEC_KEY(VEC_CROSS, "cross") +IMPALA_VEC_KEY(VEC_LENGTH, "length") +IMPALA_VEC_KEY(VEC_NORMALIZE, "normalize") + +#undef IMPALA_VEC_KEY #undef IMPALA_KEY #ifndef IMPALA_MISC From ea13eda61a2a18d31889880cbb132312e25bbc44 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Tue, 11 Oct 2016 12:40:29 +0200 Subject: [PATCH 05/17] More vector stuff --- src/impala/ast.h | 2 ++ src/impala/sema/borrowsema.cpp | 5 +++++ src/impala/sema/namesema.cpp | 5 +++++ src/impala/sema/typesema.cpp | 34 ++++++++++++++++++++++++++++++++++ src/impala/sema/typetable.h | 2 +- src/impala/sema/unifiable.cpp | 16 +++++++++++----- src/impala/sema/unifiable.h | 3 +++ src/impala/stream.cpp | 6 ++++++ 8 files changed, 67 insertions(+), 6 deletions(-) diff --git a/src/impala/ast.h b/src/impala/ast.h index 6dc9fbe94..fb8ce42a8 100644 --- a/src/impala/ast.h +++ b/src/impala/ast.h @@ -1351,6 +1351,8 @@ class VectorExpr : public Expr, public Args { virtual thorin::Value lemit(CodeGen&) const override; virtual const thorin::Def* remit(CodeGen&) const override; + bool check_vector_args(TypeSema&) const; + friend class CodeGen; friend class Parser; friend class TypeSema; diff --git a/src/impala/sema/borrowsema.cpp b/src/impala/sema/borrowsema.cpp index e79e8a063..22ff01e6f 100644 --- a/src/impala/sema/borrowsema.cpp +++ b/src/impala/sema/borrowsema.cpp @@ -243,6 +243,11 @@ void MapExpr::check(BorrowSema& sema) const { arg->check(sema); } +void VectorExpr::check(BorrowSema& sema) const { + for (auto arg : args()) + arg->check(sema); +} + void IfExpr::check(BorrowSema& sema) const { cond()->check(sema); then_expr()->check(sema); diff --git a/src/impala/sema/namesema.cpp b/src/impala/sema/namesema.cpp index 69b0d435d..057891483 100644 --- a/src/impala/sema/namesema.cpp +++ b/src/impala/sema/namesema.cpp @@ -368,6 +368,11 @@ void MapExpr::check(NameSema& sema) const { arg->check(sema); } +void VectorExpr::check(NameSema& sema) const { + for (auto arg : args()) + arg->check(sema); +} + void IfExpr::check(NameSema& sema) const { cond()->check(sema); then_expr()->check(sema); diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index 31a653e48..8b7e6b0d0 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -1194,6 +1194,40 @@ Type MapExpr::check_as_method_call(TypeSema& sema, TypeExpectation expected) con return sema.type_error(); } +Type VectorExpr::check(TypeSema& sema, TypeExpectation expected) const { + if (!num_args()) + error(this) << "arguments expected\n"; + + for (auto arg : args()) { + sema.check(arg); + } + + switch (kind()) { + case VEC3: + if (!check_vector_args(sema)) return sema.type_error(); + return sema.matrix_type(arg(0)->type(), 3); + break; + default: break; + } + return sema.type_error(); +} + +bool VectorExpr::check_vector_args(TypeSema& sema) const { + auto arg0 = arg(0); + for (auto arg : args()) { + if (arg->type() != arg0->type()) { + error(this) << "mismatching types in vector expression\n"; + return false; + } + if (!arg->type().isa() || + !arg->type().isa()) { + error(this) << "incorrect type for vector element\n"; + return false; + } + } + return true; +} + Type BlockExprBase::check(TypeSema& sema, TypeExpectation expected) const { THORIN_PUSH(sema.cur_block_, this); for (auto stmt : stmts()) diff --git a/src/impala/sema/typetable.h b/src/impala/sema/typetable.h index 58128dff2..9e596a36e 100644 --- a/src/impala/sema/typetable.h +++ b/src/impala/sema/typetable.h @@ -32,7 +32,7 @@ class TypeTable { SimdType simd_type(Type elem_type, uint64_t size) { return join(new SimdTypeNode(*this, elem_type, size)); } - MatrixType matrix_type(Type elem_type, uint32_t rows, uint32_t cols) { + MatrixType matrix_type(Type elem_type, uint32_t rows, uint32_t cols = 1) { return join(new MatrixTypeNode(*this, elem_type, rows, cols)); } MutPtrType mut_ptr_type(Type referenced_type, int addr_space = 0) { diff --git a/src/impala/sema/unifiable.cpp b/src/impala/sema/unifiable.cpp index 905029ecf..82f92f7c5 100644 --- a/src/impala/sema/unifiable.cpp +++ b/src/impala/sema/unifiable.cpp @@ -306,9 +306,17 @@ bool TraitAppNode::equal(const Unifiable* other) const { bool SimdTypeNode::equal(const Unifiable* other) const { assert(this->is_unified()); - return Unifiable::equal(other) && (this->size() == other->as()->size()); + return Unifiable::equal(other) && (size() == other->as()->size()); } +bool MatrixTypeNode::equal(const Unifiable* other) const { + assert(this->is_unified()); + return Unifiable::equal(other) && + rows() == other->as()->rows() && + cols() == other->as()->cols(); +} + + //------------------------------------------------------------------------------ /* @@ -332,10 +340,8 @@ bool DefiniteArrayTypeNode::is_subtype(const TypeNode* other) const { return dim_eq && other->isa() && elem_type()->is_subtype(*other->as()->elem_type()); } - -bool SimdTypeNode::is_subtype(const TypeNode* other) const { - return this->equal(other); -} +bool SimdTypeNode::is_subtype(const TypeNode* other) const { return equal(other); } +bool MatrixTypeNode::is_subtype(const TypeNode* other) const { return equal(other); } /* * TODO merge this code with equal diff --git a/src/impala/sema/unifiable.h b/src/impala/sema/unifiable.h index ce77ed04b..4c7e1a101 100644 --- a/src/impala/sema/unifiable.h +++ b/src/impala/sema/unifiable.h @@ -673,6 +673,9 @@ class MatrixTypeNode : public KnownTypeNode { uint32_t rows() const { return rows_; } uint32_t cols() const { return cols_; } + bool is_vector() const { return cols_ == 1; } + bool is_matrix() const { return !is_vector(); } + virtual std::ostream& stream(std::ostream&) const override; virtual bool is_subtype(const TypeNode*) const override; virtual bool equal(const Unifiable*) const override; diff --git a/src/impala/stream.cpp b/src/impala/stream.cpp index 0331c1ea0..5fe111b56 100644 --- a/src/impala/stream.cpp +++ b/src/impala/stream.cpp @@ -80,6 +80,12 @@ std::ostream& IndefiniteArrayTypeNode::stream(std::ostream& os) const { return s std::ostream& SimdTypeNode::stream(std::ostream& os) const { return streamf(os, "simd[% * %]", elem_type(), size()); } std::ostream& StructAbsTypeNode::stream(std::ostream& os) const { return os << struct_decl_->symbol(); } +std::ostream& MatrixTypeNode::stream(std::ostream& os) const { + return is_vector() ? + streamf(os, "vec%", rows()) : + streamf(os, "mat%x%", rows(), cols()); +} + std::ostream& StructAppTypeNode::stream(std::ostream& os) const { os << struct_abs_type()->struct_decl()->symbol(); if (num_args() != 0) From 6b80ac0ffe040864be63637b757752607e38d494 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Tue, 11 Oct 2016 16:44:11 +0200 Subject: [PATCH 06/17] Added simple code generation --- src/impala/ast.h | 1 - src/impala/emit.cpp | 33 +++++++++++++++++++++++++++++++++ src/impala/sema/unifiable.cpp | 4 ++++ src/impala/sema/unifiable.h | 2 +- src/impala/stream.cpp | 9 +++++++++ 5 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/impala/ast.h b/src/impala/ast.h index fb8ce42a8..ad311770c 100644 --- a/src/impala/ast.h +++ b/src/impala/ast.h @@ -1348,7 +1348,6 @@ class VectorExpr : public Expr, public Args { private: virtual std::ostream& stream(std::ostream&) const override; virtual Type check(TypeSema&, TypeExpectation) const override; - virtual thorin::Value lemit(CodeGen&) const override; virtual const thorin::Def* remit(CodeGen&) const override; bool check_vector_args(TypeSema&) const; diff --git a/src/impala/emit.cpp b/src/impala/emit.cpp index 9879a1ed4..5f53a6d16 100644 --- a/src/impala/emit.cpp +++ b/src/impala/emit.cpp @@ -182,6 +182,14 @@ const thorin::Type* SimdTypeNode::convert(CodeGen& cg) const { return cg.world().type(scalar->as()->primtype_kind(), size()); } +const thorin::Type* MatrixTypeNode::convert(CodeGen& cg) const { + int n = rows() * cols(); + auto elem = cg.convert(elem_type()); + Array args(n); + for (int i = 0; i < n; i++) args[i] = elem; + return cg.world().tuple_type(args); +} + /* * Decls and Function */ @@ -577,6 +585,31 @@ const Def* MapExpr::remit(CodeGen& cg, State state, Location eval_loc) const { THORIN_UNREACHABLE; } +const Def* VectorExpr::remit(CodeGen& cg) const { + switch (kind()) { + case VEC2: + case VEC3: + case VEC4: + case MAT2: + case MAT3: + case MAT4: + case MAT2X3: + case MAT2X4: + case MAT3X2: + case MAT3X4: + case MAT4X2: + case MAT4X3: + { + int i = 0; + Array defs(num_args()); + for (auto arg : args()) defs[i++] = cg.remit(arg); + return cg.world().tuple(defs, loc()); + } + default: break; + } + THORIN_UNREACHABLE; +} + Value FieldExpr::lemit(CodeGen& cg) const { return Value::create_agg(cg.lemit(lhs()), cg.world().literal_qu32(index(), loc())); } diff --git a/src/impala/sema/unifiable.cpp b/src/impala/sema/unifiable.cpp index 82f92f7c5..5913c5a49 100644 --- a/src/impala/sema/unifiable.cpp +++ b/src/impala/sema/unifiable.cpp @@ -481,6 +481,10 @@ Type SimdTypeNode::vinstantiate(SpecializeMap& map) const { return map[this] = *typetable().simd_type(elem_type()->specialize(map), size()); } +Type MatrixTypeNode::vinstantiate(SpecializeMap& map) const { + return map[this] = *typetable().matrix_type(elem_type()->specialize(map), rows(), cols()); +} + Type FnTypeNode::vinstantiate(SpecializeMap& map) const { return map[this] = *typetable().fn_type(specialize_args(map)); } diff --git a/src/impala/sema/unifiable.h b/src/impala/sema/unifiable.h index 4c7e1a101..eedf47f40 100644 --- a/src/impala/sema/unifiable.h +++ b/src/impala/sema/unifiable.h @@ -669,7 +669,7 @@ class MatrixTypeNode : public KnownTypeNode { : KnownTypeNode(tt, Kind_matrix, {elem_type}), rows_(rows), cols_(cols) {} - Type elem_kind() const { return arg(0); } + Type elem_type() const { return arg(0); } uint32_t rows() const { return rows_; } uint32_t cols() const { return cols_; } diff --git a/src/impala/stream.cpp b/src/impala/stream.cpp index 5fe111b56..8614abbe9 100644 --- a/src/impala/stream.cpp +++ b/src/impala/stream.cpp @@ -387,6 +387,15 @@ std::ostream& SimdExpr::stream(std::ostream& os) const { return stream_list(os, args(), [&](const Expr* expr) { os << expr; }, "simd[", "]"); } +std::ostream& VectorExpr::stream(std::ostream& os) const { + switch (kind()) { +#define IMPALA_KEY_VEC(tok, str) case tok : os << str; break; +#include "impala/tokenlist.h" + default: THORIN_UNREACHABLE; + } + return stream_list(os, args(), [&](const Expr* expr) { os << expr; }, "(", ")"); +} + std::ostream& PrefixExpr::stream(std::ostream& os) const { Prec r = PrecTable::prefix_r[kind()]; Prec old = prec; From dd2026738d32da4326942688880c73010c99fdae Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Tue, 11 Oct 2016 16:59:51 +0200 Subject: [PATCH 07/17] Fix parser --- src/impala/parser.cpp | 15 ++++++++++++++- src/impala/sema/typesema.cpp | 2 +- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/impala/parser.cpp b/src/impala/parser.cpp index a97eb91f7..ffd4f3746 100644 --- a/src/impala/parser.cpp +++ b/src/impala/parser.cpp @@ -69,7 +69,19 @@ case Token::L_BRACE: \ case Token::RUN_BLOCK: \ case Token::L_BRACKET: \ - case Token::SIMD + case Token::SIMD: \ + case Token::VEC2: \ + case Token::VEC3: \ + case Token::VEC4: \ + case Token::MAT2: \ + case Token::MAT3: \ + case Token::MAT4: \ + case Token::MAT2X3: \ + case Token::MAT2X4: \ + case Token::MAT3X2: \ + case Token::MAT3X4: \ + case Token::MAT4X2: \ + case Token::MAT4X3 #define STMT_NOT_EXPR \ Token::LET: \ @@ -1078,6 +1090,7 @@ const VectorExpr* Parser::parse_vector_expr() { #include "tokenlist.h" default: THORIN_UNREACHABLE; } + lex(); expect(Token::L_PAREN, "vector expression"); parse_comma_list("elements of a vector expression", Token::R_PAREN, [&] { vec->args_.push_back(parse_expr()); }); return vec; diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index 8b7e6b0d0..ce170bce5 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -1219,7 +1219,7 @@ bool VectorExpr::check_vector_args(TypeSema& sema) const { error(this) << "mismatching types in vector expression\n"; return false; } - if (!arg->type().isa() || + if (!arg->type().isa() && !arg->type().isa()) { error(this) << "incorrect type for vector element\n"; return false; From a497ec1232950f58407c18de8b2d536a6c9e8c8a Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Tue, 11 Oct 2016 17:21:13 +0200 Subject: [PATCH 08/17] Added AST types for vectors --- src/impala/ast.h | 19 +++++++++++ src/impala/parser.cpp | 61 ++++++++++++++++++++++++++++------ src/impala/sema/borrowsema.cpp | 4 +++ src/impala/sema/namesema.cpp | 4 +++ src/impala/sema/typesema.cpp | 13 ++++++-- src/impala/sema/unifiable.h | 6 ++-- src/impala/stream.cpp | 9 +++-- 7 files changed, 98 insertions(+), 18 deletions(-) diff --git a/src/impala/ast.h b/src/impala/ast.h index ad311770c..569cbc020 100644 --- a/src/impala/ast.h +++ b/src/impala/ast.h @@ -422,6 +422,25 @@ class SimdASTType : public ArrayASTType { friend class Parser; }; +class MatrixASTType : public ArrayASTType { +public: + uint32_t rows() const { return rows_; } + uint32_t cols() const { return cols_; } + + bool is_vector() const { return cols_ == 1; } + + virtual std::ostream& stream(std::ostream&) const override; + virtual void check(NameSema&) const override; + virtual void check(BorrowSema&) const override; + +private: + virtual Type check(TypeSema&) const override; + + thorin::u32 rows_, cols_; + + friend class Parser; +}; + //------------------------------------------------------------------------------ /* diff --git a/src/impala/parser.cpp b/src/impala/parser.cpp index ffd4f3746..8c4cfe1b0 100644 --- a/src/impala/parser.cpp +++ b/src/impala/parser.cpp @@ -202,16 +202,17 @@ class Parser { void parse_return_param(Fn* fn); // types - const ASTType* parse_type(); - const ArrayASTType* parse_array_type(); - const Typeof* parse_typeof(); - const ASTType* parse_return_type(bool&); - const FnASTType* parse_fn_type(); - const PrimASTType* parse_prim_type(); - const PtrASTType* parse_ptr_type(); - const TupleASTType* parse_tuple_type(); - const SimdASTType* parse_simd_type(); - const ASTTypeApp* parse_type_app(); + const ASTType* parse_type(); + const ArrayASTType* parse_array_type(); + const Typeof* parse_typeof(); + const ASTType* parse_return_type(bool&); + const FnASTType* parse_fn_type(); + const PrimASTType* parse_prim_type(); + const PtrASTType* parse_ptr_type(); + const TupleASTType* parse_tuple_type(); + const SimdASTType* parse_simd_type(); + const MatrixASTType* parse_matrix_type(); + const ASTTypeApp* parse_type_app(); enum class BodyMode { None, Optional, Mandatory }; @@ -722,6 +723,21 @@ const ASTType* Parser::parse_type() { case Token::AND: case Token::ANDAND: return parse_ptr_type(); case Token::SIMD: return parse_simd_type(); + + case Token::VEC2: + case Token::VEC3: + case Token::VEC4: + case Token::MAT2: + case Token::MAT3: + case Token::MAT4: + case Token::MAT2X3: + case Token::MAT2X4: + case Token::MAT3X2: + case Token::MAT3X4: + case Token::MAT4X2: + case Token::MAT4X3: + return parse_matrix_type(); + default: { error("type", ""); auto error_type = new ErrorASTType(prev_loc()); @@ -868,6 +884,31 @@ const SimdASTType* Parser::parse_simd_type() { return simd; } +const MatrixASTType* Parser::parse_matrix_type() { + auto mat = loc(new MatrixASTType()); + mat->cols_ = 1; + switch (la()) { + case Token::VEC2: mat->rows_ = 2; mat->cols_ = 1; break; + case Token::VEC3: mat->rows_ = 3; mat->cols_ = 1; break; + case Token::VEC4: mat->rows_ = 4; mat->cols_ = 1; break; + case Token::MAT2: mat->rows_ = 2; mat->cols_ = 2; break; + case Token::MAT3: mat->rows_ = 2; mat->cols_ = 3; break; + case Token::MAT4: mat->rows_ = 2; mat->cols_ = 4; break; + case Token::MAT2X3: mat->rows_ = 2; mat->cols_ = 3; break; + case Token::MAT2X4: mat->rows_ = 2; mat->cols_ = 4; break; + case Token::MAT3X2: mat->rows_ = 3; mat->cols_ = 2; break; + case Token::MAT3X4: mat->rows_ = 3; mat->cols_ = 4; break; + case Token::MAT4X2: mat->rows_ = 4; mat->cols_ = 2; break; + case Token::MAT4X3: mat->rows_ = 4; mat->cols_ = 3; break; + default: THORIN_UNREACHABLE; + } + lex(); + expect(Token::L_BRACKET, "vector or matrix type"); + mat->elem_type_ = parse_type(); + expect(Token::R_BRACKET, "vector or matrix type"); + return mat; +} + /* * expressions */ diff --git a/src/impala/sema/borrowsema.cpp b/src/impala/sema/borrowsema.cpp index 22ff01e6f..962a5d2c6 100644 --- a/src/impala/sema/borrowsema.cpp +++ b/src/impala/sema/borrowsema.cpp @@ -76,6 +76,10 @@ void SimdASTType::check(BorrowSema& sema) const { elem_type()->check(sema); } +void MatrixASTType::check(BorrowSema& sema) const { + elem_type()->check(sema); +} + //------------------------------------------------------------------------------ diff --git a/src/impala/sema/namesema.cpp b/src/impala/sema/namesema.cpp index 057891483..a8096684d 100644 --- a/src/impala/sema/namesema.cpp +++ b/src/impala/sema/namesema.cpp @@ -169,6 +169,10 @@ void SimdASTType::check(NameSema& sema) const { elem_type()->check(sema); } +void MatrixASTType::check(NameSema& sema) const { + elem_type()->check(sema); +} + //------------------------------------------------------------------------------ /* diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index ce170bce5..bd676bcd3 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -365,6 +365,16 @@ Type SimdASTType::check(TypeSema& sema) const { } } +Type MatrixASTType::check(TypeSema& sema) const { + auto type = sema.check(elem_type()); + if (type.isa() || type.isa()) + return sema.matrix_type(type, rows(), cols()); + else { + error(this) << "vector or matrix types can only be used with primitive or simd types\n"; + return sema.type_error(); + } +} + //------------------------------------------------------------------------------ Type ValueDecl::check(TypeSema& sema) const { return check(sema, Type()); } @@ -1219,8 +1229,7 @@ bool VectorExpr::check_vector_args(TypeSema& sema) const { error(this) << "mismatching types in vector expression\n"; return false; } - if (!arg->type().isa() && - !arg->type().isa()) { + if (!arg->type().isa() && !arg->type().isa()) { error(this) << "incorrect type for vector element\n"; return false; } diff --git a/src/impala/sema/unifiable.h b/src/impala/sema/unifiable.h index eedf47f40..9f5e06a65 100644 --- a/src/impala/sema/unifiable.h +++ b/src/impala/sema/unifiable.h @@ -663,18 +663,16 @@ class SimdTypeNode : public ArrayTypeNode { const uint64_t size_; }; -class MatrixTypeNode : public KnownTypeNode { +class MatrixTypeNode : public ArrayTypeNode { public: MatrixTypeNode(TypeTable& tt, Type elem_type, uint32_t rows, uint32_t cols) - : KnownTypeNode(tt, Kind_matrix, {elem_type}), rows_(rows), cols_(cols) + : ArrayTypeNode(tt, Kind_matrix, { elem_type }), rows_(rows), cols_(cols) {} - Type elem_type() const { return arg(0); } uint32_t rows() const { return rows_; } uint32_t cols() const { return cols_; } bool is_vector() const { return cols_ == 1; } - bool is_matrix() const { return !is_vector(); } virtual std::ostream& stream(std::ostream&) const override; virtual bool is_subtype(const TypeNode*) const override; diff --git a/src/impala/stream.cpp b/src/impala/stream.cpp index 8614abbe9..945d1f20c 100644 --- a/src/impala/stream.cpp +++ b/src/impala/stream.cpp @@ -82,8 +82,8 @@ std::ostream& StructAbsTypeNode::stream(std::ostream& os) const { return os << s std::ostream& MatrixTypeNode::stream(std::ostream& os) const { return is_vector() ? - streamf(os, "vec%", rows()) : - streamf(os, "mat%x%", rows(), cols()); + streamf(os, "vec%[%]", rows(), elem_type()) : + streamf(os, "mat%x%[%]", rows(), cols(), elem_type()); } std::ostream& StructAppTypeNode::stream(std::ostream& os) const { @@ -121,6 +121,11 @@ std::ostream& PtrASTType::stream(std::ostream& os) const { std::ostream& DefiniteArrayASTType::stream(std::ostream& os) const { return streamf(os, "[% * %]", elem_type(), dim()); } std::ostream& IndefiniteArrayASTType::stream(std::ostream& os) const { return streamf(os, "[%]", elem_type()); } std::ostream& SimdASTType::stream(std::ostream& os) const { return streamf(os, "simd[% * %]", elem_type(), size()); } +std::ostream& MatrixASTType::stream(std::ostream& os) const { + return is_vector() ? + streamf(os, "vec%[%]", rows(), elem_type()) : + streamf(os, "mat%x%[%]", rows(), cols(), elem_type()); +} std::ostream& TupleASTType::stream(std::ostream& os) const { return stream_list(os, args(), [&](const ASTType* type) { os << type; }, "(", ")"); From 0bfdcf38b0b52f43119bc8da6a876c9e84f0323c Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Thu, 13 Oct 2016 17:47:40 +0200 Subject: [PATCH 09/17] More type checks for vectors --- src/impala/ast.cpp | 2 +- src/impala/ast.h | 15 +- src/impala/emit.cpp | 53 +++++-- src/impala/parser.cpp | 14 +- src/impala/sema/borrowsema.cpp | 2 +- src/impala/sema/namesema.cpp | 2 +- src/impala/sema/typesema.cpp | 263 ++++++++++++++++++++++++++++----- src/impala/sema/unifiable.h | 1 + src/impala/stream.cpp | 2 +- src/impala/tokenlist.h | 44 +++--- 10 files changed, 315 insertions(+), 83 deletions(-) diff --git a/src/impala/ast.cpp b/src/impala/ast.cpp index 7915895b7..21681a304 100644 --- a/src/impala/ast.cpp +++ b/src/impala/ast.cpp @@ -86,7 +86,7 @@ bool MapExpr::is_lvalue() const { } bool PrefixExpr::is_lvalue() const { return (kind() == MUL || kind() == AND) && rhs()->is_lvalue(); } -bool FieldExpr::is_lvalue() const { return lhs()->is_lvalue(); } +bool FieldExpr::is_lvalue() const { return lhs()->is_lvalue() && !lhs()->type().isa(); } bool CastExpr::is_lvalue() const { return lhs()->is_lvalue(); } //------------------------------------------------------------------------------ diff --git a/src/impala/ast.h b/src/impala/ast.h index 569cbc020..bf60b9a34 100644 --- a/src/impala/ast.h +++ b/src/impala/ast.h @@ -1116,6 +1116,8 @@ class InfixExpr : public Expr { virtual std::ostream& stream(std::ostream&) const override; virtual Type check(TypeSema&, TypeExpectation) const override; + Type check_arith_op(TypeSema&) const; + Kind kind_; AutoPtr lhs_; AutoPtr rhs_; @@ -1157,11 +1159,11 @@ class FieldExpr : public Expr { const Identifier* identifier() const { return identifier_; } Symbol symbol() const { return identifier()->symbol(); } uint32_t index() const { return index_; } + const std::vector& swizzle() const { return swizzle_; } virtual bool is_lvalue() const override; virtual void take_address() const override; virtual void check(NameSema&) const override; virtual void check(BorrowSema&) const override; - Type check_as_struct(TypeSema&, Type) const; private: virtual std::ostream& stream(std::ostream&) const override; @@ -1169,9 +1171,13 @@ class FieldExpr : public Expr { virtual thorin::Value lemit(CodeGen&) const override; virtual const thorin::Def* remit(CodeGen&) const override; + Type check_as_struct(TypeSema&, Type) const; + Type check_as_matrix(TypeSema&, Type) const; + AutoPtr lhs_; AutoPtr identifier_; mutable uint32_t index_ = uint32_t(-1); + mutable std::vector swizzle_; friend class Parser; friend class MapExpr; // remove this @@ -1353,10 +1359,10 @@ class MapExpr : public Expr, public Args, public TypeArgs { friend class TypeSema; }; -class VectorExpr : public Expr, public Args { +class MatrixExpr : public Expr, public Args { public: enum Kind { -#define IMPALA_VEC_KEY(tok, str) tok = Token:: tok, +#define IMPALA_MAT_KEY(tok, str, r, c) tok = Token:: tok, #include "tokenlist.h" }; @@ -1369,7 +1375,8 @@ class VectorExpr : public Expr, public Args { virtual Type check(TypeSema&, TypeExpectation) const override; virtual const thorin::Def* remit(CodeGen&) const override; - bool check_vector_args(TypeSema&) const; + Type check_vector_args(TypeSema&, uint32_t) const; + Type check_matrix_args(TypeSema&, uint32_t, uint32_t) const; friend class CodeGen; friend class Parser; diff --git a/src/impala/emit.cpp b/src/impala/emit.cpp index 5f53a6d16..3a7193cf0 100644 --- a/src/impala/emit.cpp +++ b/src/impala/emit.cpp @@ -183,11 +183,7 @@ const thorin::Type* SimdTypeNode::convert(CodeGen& cg) const { } const thorin::Type* MatrixTypeNode::convert(CodeGen& cg) const { - int n = rows() * cols(); - auto elem = cg.convert(elem_type()); - Array args(n); - for (int i = 0; i < n; i++) args[i] = elem; - return cg.world().tuple_type(args); + return cg.world().definite_array_type(cg.convert(elem_type()), size()); } /* @@ -585,7 +581,7 @@ const Def* MapExpr::remit(CodeGen& cg, State state, Location eval_loc) const { THORIN_UNREACHABLE; } -const Def* VectorExpr::remit(CodeGen& cg) const { +const Def* MatrixExpr::remit(CodeGen& cg) const { switch (kind()) { case VEC2: case VEC3: @@ -598,23 +594,56 @@ const Def* VectorExpr::remit(CodeGen& cg) const { case MAT3X2: case MAT3X4: case MAT4X2: - case MAT4X3: - { - int i = 0; - Array defs(num_args()); - for (auto arg : args()) defs[i++] = cg.remit(arg); - return cg.world().tuple(defs, loc()); + case MAT4X3: { + auto mat = type().as(); + int i = 0, n = mat->size(); + Array defs(n); + for (auto arg : args()) { + auto def = cg.remit(arg); + if (arg->type().isa()) { + for (int j = 0, m = def->size(); j < m; j++) + defs[i++] = cg.world().extract(def, j, loc()); + } else { + defs[i++] = def; + } + } + + if (i == 1) { + // repetition constructor, like so: vec4(1.0f) + if (mat->is_vector()) { + for (; i < n; i++) defs[i] = defs[0]; + } else { + // for matrices, this means defs[0] * identity + auto z = cg.world().zero(cg.convert(mat->elem_type()), loc()); + for (int i = 0, n = mat->rows(); i < n; i++) { + for (int j = 0, m = mat->cols(); j < m; j++) + defs[i * mat->cols() + j] = i == j ? defs[0] : z; + } + } } + return cg.world().definite_array(defs, loc()); + } default: break; } THORIN_UNREACHABLE; } Value FieldExpr::lemit(CodeGen& cg) const { + assert(!lhs()->type().isa()); return Value::create_agg(cg.lemit(lhs()), cg.world().literal_qu32(index(), loc())); } const Def* FieldExpr::remit(CodeGen& cg) const { + if (lhs()->type().isa()) { + int i = 0, n = swizzle().size(); + auto def = cg.remit(lhs()); + Array defs(n); + for (auto s : swizzle()) { + defs[i++] = cg.world().extract(def, s, loc()); + } + return n > 1 ? cg.world().definite_array(defs, loc()) : defs[0]; + } + return cg.extract(cg.remit(lhs()), index(), loc()); } diff --git a/src/impala/parser.cpp b/src/impala/parser.cpp index 8c4cfe1b0..be28507de 100644 --- a/src/impala/parser.cpp +++ b/src/impala/parser.cpp @@ -242,7 +242,7 @@ class Parser { const Expr* parse_infix_expr(const Expr* lhs); const Expr* parse_postfix_expr(const Expr* lhs); const Expr* parse_primary_expr(); - const VectorExpr* parse_vector_expr(); + const MatrixExpr* parse_matrix_expr(); const LiteralExpr* parse_literal_expr(); const CharExpr* parse_char_expr(); const StrExpr* parse_str_expr(); @@ -1062,9 +1062,9 @@ const Expr* Parser::parse_primary_expr() { parse_comma_list("elements of a simd expression", Token::R_BRACKET, [&] { simd->args_.push_back(parse_expr()); }); return simd; } -#define IMPALA_VEC_KEY(tok, str) case Token:: tok: +#define IMPALA_MAT_KEY(tok, str, r, c) case Token:: tok: #include "impala/tokenlist.h" - return parse_vector_expr(); + return parse_matrix_expr(); #define IMPALA_LIT(itype, atype) \ case Token::LIT_##itype: @@ -1124,11 +1124,11 @@ const Expr* Parser::parse_primary_expr() { } } -const VectorExpr* Parser::parse_vector_expr() { - auto vec = loc(new VectorExpr()); +const MatrixExpr* Parser::parse_matrix_expr() { + auto vec = loc(new MatrixExpr()); switch (la()) { -#define IMPALA_VEC_KEY(tok, str) case Token:: tok: vec->kind_ = VectorExpr:: tok; break; -#include "tokenlist.h" +#define IMPALA_MAT_KEY(tok, str, r, c) case Token:: tok: vec->kind_ = MatrixExpr:: tok; break; +#include "impala/tokenlist.h" default: THORIN_UNREACHABLE; } lex(); diff --git a/src/impala/sema/borrowsema.cpp b/src/impala/sema/borrowsema.cpp index 962a5d2c6..c84bb4c27 100644 --- a/src/impala/sema/borrowsema.cpp +++ b/src/impala/sema/borrowsema.cpp @@ -247,7 +247,7 @@ void MapExpr::check(BorrowSema& sema) const { arg->check(sema); } -void VectorExpr::check(BorrowSema& sema) const { +void MatrixExpr::check(BorrowSema& sema) const { for (auto arg : args()) arg->check(sema); } diff --git a/src/impala/sema/namesema.cpp b/src/impala/sema/namesema.cpp index a8096684d..c68c81f86 100644 --- a/src/impala/sema/namesema.cpp +++ b/src/impala/sema/namesema.cpp @@ -372,7 +372,7 @@ void MapExpr::check(NameSema& sema) const { arg->check(sema); } -void VectorExpr::check(NameSema& sema) const { +void MatrixExpr::check(NameSema& sema) const { for (auto arg : args()) arg->check(sema); } diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index bd676bcd3..9b316d075 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -65,6 +65,7 @@ class TypeSema : public TypeTable { bool is_float(Type t); bool expect_int(const Expr*); bool expect_int_or_bool(const Expr*); + bool expect_num(Type, const Expr*); void expect_num(const Expr*); Type expect_type(const Expr* expr, Type found, TypeExpectation expected); Type expect_type(const Expr* expr, TypeExpectation expected) { return expect_type(expr, expr->type(), expected); } @@ -172,11 +173,16 @@ bool TypeSema::expect_int_or_bool(const Expr* expr) { return true; } -void TypeSema::expect_num(const Expr* expr) { - auto t = scalar_type(expr); - - if (!t->is_error() && !t->is_bool() && !is_int(t) && !is_float(t)) +bool TypeSema::expect_num(Type t, const Expr* expr) { + if (!t->is_error() && !t->is_bool() && !is_int(t) && !is_float(t)) { error(expr) << "expected number type but found " << t << "\n"; + return false; + } + return true; +} + +void TypeSema::expect_num(const Expr* expr) { + expect_num(scalar_type(expr), expr); } Type TypeSema::expect_type(const Expr* expr, Type found_type, TypeExpectation expected) { @@ -367,12 +373,17 @@ Type SimdASTType::check(TypeSema& sema) const { Type MatrixASTType::check(TypeSema& sema) const { auto type = sema.check(elem_type()); - if (type.isa() || type.isa()) + if (type.isa() || type.isa()) { + if (!sema.is_int(type) && !sema.is_float(type)) { + error(this) << "only floating point and integer types are supported for " + << (is_vector() ? "vectors" : "matrices") << "\n"; + return sema.type_error(); + } return sema.matrix_type(type, rows(), cols()); - else { - error(this) << "vector or matrix types can only be used with primitive or simd types\n"; - return sema.type_error(); } + + error(this) << "vector or matrix types only support primitive or simd types\n"; + return sema.type_error(); } //------------------------------------------------------------------------------ @@ -840,13 +851,10 @@ Type InfixExpr::check(TypeSema& sema, TypeExpectation expected) const { return sema.type_bool(); case ADD: case SUB: - case MUL: case DIV: + case MUL: case REM: { - auto type = sema.check(lhs(), sema.check(rhs(), expected)); - sema.expect_num(lhs()); - sema.expect_num(rhs()); - return type; + return check_arith_op(sema); } case SHL: case SHR: { @@ -873,12 +881,9 @@ Type InfixExpr::check(TypeSema& sema, TypeExpectation expected) const { case MUL_ASGN: case DIV_ASGN: case REM_ASGN: { - sema.check(rhs(), sema.check(lhs())); - if (sema.expect_lvalue(lhs())) { - sema.expect_num(lhs()); - sema.expect_num(rhs()); + check_arith_op(sema); + if (sema.expect_lvalue(lhs())) return sema.unit(); - } break; } case AND_ASGN: @@ -901,6 +906,60 @@ Type InfixExpr::check(TypeSema& sema, TypeExpectation expected) const { return sema.type_error(); } +inline Type matrix_elem_type(Type t) { + if (auto mat = t.isa()) + return mat->elem_type(); + return t; +} + +Type InfixExpr::check_arith_op(TypeSema& sema) const { + auto a = sema.check(lhs()); + auto b = sema.check(rhs()); + + auto elem_a = matrix_elem_type(a); + auto elem_b = matrix_elem_type(b); + bool scalar_a = a == elem_a; + bool scalar_b = b == elem_b; + + if (kind() == REM) { + // For REM, the types must be integers + if (!sema.is_int(elem_a) || !sema.is_int(elem_b)) { + error(this) << "the modulus '%' operator is only valid on integers\n"; + return sema.type_error(); + } + } else { + sema.expect_num(elem_a, lhs()); + sema.expect_num(elem_b, rhs()); + } + + // if operands are both scalars or both vectors, types must be equal + if ((scalar_a & scalar_b) && a == b) return a; + if ((scalar_a ^ scalar_b) && elem_a == elem_b) { return scalar_a ? b : a; } + + if (!scalar_a && !scalar_b && elem_a == elem_b) { + auto mat_a = elem_a.as(); + auto mat_b = elem_b.as(); + + if (kind() == MUL) { + // vector * matrix, matrix * vector, or matrix * matrix multiplication + if (!mat_a->is_vector()) { + if (mat_a->cols() == mat_b->rows()) return mat_b; + } else { + if (mat_a->rows() == mat_b->rows()) return mat_a; + } + } else if (mat_a->rows() == mat_b->rows()) { + // addition, subtraction, division, ... + return mat_a; + } + } + + if (!a->is_error() && !b->is_error()) { + error(this) << "types do not match for operator '" << Token(loc(), (Token::Kind)kind()) + << "', got " << a << " and " << b << " \n"; + } + return sema.type_error(); +} + Type PostfixExpr::check(TypeSema& sema, TypeExpectation expected) const { // TODO check if operator supports the type sema.check(lhs(), expected); @@ -1108,6 +1167,10 @@ Type TypeSema::check_call(const MapExpr* expr, FnType fn_poly, const ASTTypes& t } Type FieldExpr::check(TypeSema& sema, TypeExpectation expected) const { + auto ltype = sema.check(lhs()); + if (ltype.isa()) + return check_as_matrix(sema, expected.type()); + if (auto type = check_as_struct(sema, expected.type())) return type; @@ -1136,6 +1199,45 @@ Type FieldExpr::check_as_struct(TypeSema& sema, Type expected) const { return Type(); } +Type FieldExpr::check_as_matrix(TypeSema& sema, Type expected) const { + auto ltype = sema.check(lhs()).as(); + if (ltype->is_vector()) { + const char* str = symbol().str(); + uint32_t len_dst = strlen(str); + if (len_dst > 4) { + error(this) << "too many components in swizzle operation\n"; + return sema.type_error(); + } + + uint32_t len_src = 0; + for (auto* p = str; *p; p++) { + uint32_t l; + switch (*p) { + case 'x': l = 0; break; + case 'y': l = 1; break; + case 'z': l = 2; break; + case 'w': l = 3; break; + default: + error(this) << "incorrect character in swizzle operation, only 'x', 'y', 'z', and 'w' are allowed\n"; + return sema.type_error(); + } + swizzle_.push_back(l); + len_src = std::max(len_src, l); + } + + if (len_src >= ltype->rows()) { + const char xyzw[] = "xyzw"; + error(this) << "vector component '" << xyzw[len_src] << "' is out of bounds\n"; + return sema.type_error(); + } + + return len_dst > 1 ? static_cast(sema.matrix_type(ltype->elem_type(), len_dst)) : ltype->elem_type(); + } + + error(this) << "matrix components cannot be accessed by field names, use the array syntax instead\n"; + return sema.type_error(); +} + Type MapExpr::check(TypeSema& sema, TypeExpectation expected) const { if (auto field_expr = lhs()->isa()) { if (field_expr->check_as_struct(sema, sema.unknown_type())) @@ -1204,7 +1306,7 @@ Type MapExpr::check_as_method_call(TypeSema& sema, TypeExpectation expected) con return sema.type_error(); } -Type VectorExpr::check(TypeSema& sema, TypeExpectation expected) const { +Type MatrixExpr::check(TypeSema& sema, TypeExpectation expected) const { if (!num_args()) error(this) << "arguments expected\n"; @@ -1212,29 +1314,122 @@ Type VectorExpr::check(TypeSema& sema, TypeExpectation expected) const { sema.check(arg); } + // Check that the number and type of arguments is correct + uint32_t rows, cols; switch (kind()) { - case VEC3: - if (!check_vector_args(sema)) return sema.type_error(); - return sema.matrix_type(arg(0)->type(), 3); - break; - default: break; +#define IMPALA_MAT_KEY(tok, str, r, c) case Token:: tok : rows = r; cols = c; break; +#include "impala/tokenlist.h" + default: THORIN_UNREACHABLE; } - return sema.type_error(); + + if (rows > 0) { + if (cols > 1) { return check_matrix_args(sema, rows, cols); } + else { return check_vector_args(sema, rows); } + } else { + // Special functions (e.g. dot, cross, ...) have rows = cols = 0 + switch (kind()) { + /*case Token::MAT_INVERSE: return check_vector_op(sema, 1); + case Token::VEC_DOT: return check_vector_op(sema, 2, true); + case Token::VEC_LENGTH: return check_vector_op(sema, 1, true); + case Token::VEC_NORMALIZE: return check_vector_op(sema, 1, false); + case Token::VEC_CROSS: return check_vector_op(sema, 2, false);*/ + default: break; + } + } + + THORIN_UNREACHABLE; } -bool VectorExpr::check_vector_args(TypeSema& sema) const { - auto arg0 = arg(0); +Type MatrixExpr::check_vector_args(TypeSema& sema, uint32_t rows) const { + // vectors can be constructed by concatenating scalars with smaller vectors + auto type = arg(0)->type(); + + Type elem_type; + if (auto mat = type.isa()) + elem_type = mat->elem_type(); + else + elem_type = type; + + if (!sema.is_int(elem_type) && !sema.is_float(elem_type)) { + error(this) << "only floating point and integer types are supported for vectors\n"; + return sema.type_error(); + } + + uint32_t count = 0; for (auto arg : args()) { - if (arg->type() != arg0->type()) { - error(this) << "mismatching types in vector expression\n"; - return false; + if (arg->type().isa() || arg->type().isa()) { + if (arg->type() != elem_type) { + error(this) << "mismatching types in vector initializers\n"; + return sema.type_error(); + } + count += 1; + } else if (auto mat = arg->type().isa()) { + if (mat->elem_type() != elem_type) { + error(this) << "mismatching types in vector initializers: got " + << mat->elem_type() << ", expected " << elem_type << "\n"; + return sema.type_error(); + } + + if (!mat->is_vector()) { + error(this) << "matrices are not allowed in vector initializers\n"; + return sema.type_error(); + } + count += mat->rows(); + } else { + error(this) << "incorrect type for vector initializer\n"; + return sema.type_error(); } - if (!arg->type().isa() && !arg->type().isa()) { - error(this) << "incorrect type for vector element\n"; - return false; + } + + if (count != rows && count > 1) { + error(this) << "incorrect number of initializers for vector: got " + << count << ", expected " << rows << "\n"; + return sema.type_error(); + } + + return sema.matrix_type(elem_type, rows); +} + +Type MatrixExpr::check_matrix_args(TypeSema& sema, uint32_t rows, uint32_t cols) const { + // matrices can be constructed by providing the columns as vectors, or by listing all the components one by one. + auto type = arg(0)->type(); + for (auto arg : args()) { + if (arg->type() != type) { + error(this) << "mismatching types in matrix initializers\n"; + return sema.type_error(); } } - return true; + + if (auto mat = type.isa()) { + if (!mat->is_vector()) { + error(this) << "matrices cannot be initialized with matrices\n"; + return sema.type_error(); + } + + if (num_args() != cols || mat->rows() != rows) { + error(this) << "matrix initializer sizes do not match matrix dimensions\n"; + return sema.type_error(); + } + + return sema.matrix_type(mat->elem_type(), rows, cols); + } + + if (!type.isa() && !type.isa()) { + error(this) << "incorrect type for matrix initializer\n"; + return sema.type_error(); + } + + if (!sema.is_int(type) && !sema.is_float(type)) { + error(this) << "only floating point and integer types are supported for matrices\n"; + return sema.type_error(); + } + + if (num_args() > 1 && num_args() != rows * cols) { + error(this) << "incorrect number of initializers for vector: got " + << num_args() << ", expected " << rows * cols << "\n"; + } + + return sema.matrix_type(type, rows, cols); } Type BlockExprBase::check(TypeSema& sema, TypeExpectation expected) const { diff --git a/src/impala/sema/unifiable.h b/src/impala/sema/unifiable.h index 9f5e06a65..f345a25a0 100644 --- a/src/impala/sema/unifiable.h +++ b/src/impala/sema/unifiable.h @@ -671,6 +671,7 @@ class MatrixTypeNode : public ArrayTypeNode { uint32_t rows() const { return rows_; } uint32_t cols() const { return cols_; } + uint32_t size() const { return rows_ * cols_; } bool is_vector() const { return cols_ == 1; } diff --git a/src/impala/stream.cpp b/src/impala/stream.cpp index 945d1f20c..cbe5a08e1 100644 --- a/src/impala/stream.cpp +++ b/src/impala/stream.cpp @@ -392,7 +392,7 @@ std::ostream& SimdExpr::stream(std::ostream& os) const { return stream_list(os, args(), [&](const Expr* expr) { os << expr; }, "simd[", "]"); } -std::ostream& VectorExpr::stream(std::ostream& os) const { +std::ostream& MatrixExpr::stream(std::ostream& os) const { switch (kind()) { #define IMPALA_KEY_VEC(tok, str) case tok : os << str; break; #include "impala/tokenlist.h" diff --git a/src/impala/tokenlist.h b/src/impala/tokenlist.h index da0177129..132c1a5e5 100644 --- a/src/impala/tokenlist.h +++ b/src/impala/tokenlist.h @@ -104,30 +104,30 @@ IMPALA_KEY(TYPEOF, "typeof") IMPALA_KEY(WHILE, "while") IMPALA_KEY(SIMD, "simd") -#ifndef IMPALA_VEC_KEY -#define IMPALA_VEC_KEY(tok, str) IMPALA_KEY(tok, str) +#ifndef IMPALA_MAT_KEY +#define IMPALA_MAT_KEY(tok, str, r, c) IMPALA_KEY(tok, str) #endif -IMPALA_VEC_KEY(VEC2, "vec2") -IMPALA_VEC_KEY(VEC3, "vec3") -IMPALA_VEC_KEY(VEC4, "vec4") -IMPALA_VEC_KEY(MAT2, "mat2") -IMPALA_VEC_KEY(MAT3, "mat3") -IMPALA_VEC_KEY(MAT4, "mat4") -IMPALA_VEC_KEY(MAT2X3, "mat2x3") -IMPALA_VEC_KEY(MAT2X4, "mat2x4") -IMPALA_VEC_KEY(MAT3X2, "mat3x2") -IMPALA_VEC_KEY(MAT3X4, "mat3x4") -IMPALA_VEC_KEY(MAT4X2, "mat4x2") -IMPALA_VEC_KEY(MAT4X3, "mat4x3") - -IMPALA_VEC_KEY(MAT_INVERSE, "inverse") -IMPALA_VEC_KEY(VEC_DOT, "dot") -IMPALA_VEC_KEY(VEC_CROSS, "cross") -IMPALA_VEC_KEY(VEC_LENGTH, "length") -IMPALA_VEC_KEY(VEC_NORMALIZE, "normalize") - -#undef IMPALA_VEC_KEY +IMPALA_MAT_KEY(VEC2, "vec2", 2, 1) +IMPALA_MAT_KEY(VEC3, "vec3", 3, 1) +IMPALA_MAT_KEY(VEC4, "vec4", 4, 1) +IMPALA_MAT_KEY(MAT2, "mat2", 2, 2) +IMPALA_MAT_KEY(MAT3, "mat3", 3, 3) +IMPALA_MAT_KEY(MAT4, "mat4", 4, 4) +IMPALA_MAT_KEY(MAT2X3, "mat2x3", 2, 3) +IMPALA_MAT_KEY(MAT2X4, "mat2x4", 2, 4) +IMPALA_MAT_KEY(MAT3X2, "mat3x2", 3, 2) +IMPALA_MAT_KEY(MAT3X4, "mat3x4", 3, 4) +IMPALA_MAT_KEY(MAT4X2, "mat4x2", 4, 2) +IMPALA_MAT_KEY(MAT4X3, "mat4x3", 4, 3) + +IMPALA_MAT_KEY(MAT_INVERSE, "inverse", 0, 0) +IMPALA_MAT_KEY(VEC_DOT, "dot", 0, 0) +IMPALA_MAT_KEY(VEC_CROSS, "cross", 0, 0) +IMPALA_MAT_KEY(VEC_LENGTH, "length", 0, 0) +IMPALA_MAT_KEY(VEC_NORMALIZE, "normalize", 0, 0) + +#undef IMPALA_MAT_KEY #undef IMPALA_KEY #ifndef IMPALA_MISC From 9ed419894a62bedb6a8be76736213d511e5b139f Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Thu, 13 Oct 2016 19:28:58 +0200 Subject: [PATCH 10/17] Component-wise operations --- src/impala/ast.h | 7 ++++++ src/impala/emit.cpp | 31 +++++++++++++++++++------ src/impala/sema/typesema.cpp | 45 +++++++++++++++++++----------------- 3 files changed, 55 insertions(+), 28 deletions(-) diff --git a/src/impala/ast.h b/src/impala/ast.h index bf60b9a34..e2ffa14d2 100644 --- a/src/impala/ast.h +++ b/src/impala/ast.h @@ -1117,11 +1117,18 @@ class InfixExpr : public Expr { virtual Type check(TypeSema&, TypeExpectation) const override; Type check_arith_op(TypeSema&) const; + const thorin::Def* emit(CodeGen&, TokenKind, const thorin::Def*, const thorin::Def*) const; Kind kind_; AutoPtr lhs_; AutoPtr rhs_; + mutable enum VecOp { + SCALAR, + VECTOR, + MATRIX + } lvec_, rvec_; + friend class Parser; }; diff --git a/src/impala/emit.cpp b/src/impala/emit.cpp index 3a7193cf0..61e4dfc16 100644 --- a/src/impala/emit.cpp +++ b/src/impala/emit.cpp @@ -470,16 +470,13 @@ const Def* InfixExpr::remit(CodeGen& cg) const { return cg.converge(this, x); } default: - const TokenKind op = (TokenKind) kind(); - + auto op = (TokenKind)kind(); if (Token::is_assign(op)) { Value lvar = cg.lemit(lhs()); const Def* rdef = cg.remit(rhs()); - if (op != Token::ASGN) { - TokenKind sop = Token::separate_assign(op); - rdef = cg.world().binop(Token::to_binop(sop), lvar.load(loc()), rdef, loc()); - } + if (op != Token::ASGN) + rdef = emit(cg, Token::separate_assign(op), lvar.load(loc()), rdef); lvar.store(rdef, loc()); return cg.world().tuple({}, loc()); @@ -487,10 +484,30 @@ const Def* InfixExpr::remit(CodeGen& cg) const { const Def* ldef = cg.remit(lhs()); const Def* rdef = cg.remit(rhs()); - return cg.world().binop(Token::to_binop(op), ldef, rdef, loc()); + return emit(cg, op, ldef, rdef); } } +const Def* InfixExpr::emit(CodeGen& cg, TokenKind op, const Def* ldef, const Def* rdef) const { + if (lvec_ == SCALAR && rvec_ == SCALAR) + return cg.world().binop(Token::to_binop(op), ldef, rdef, loc()); + + typedef std::function ExtractFn; + + auto lextract = lvec_ == SCALAR ? + ExtractFn([&] (int i) { return ldef; }) : + ExtractFn([&] (int i) { return cg.world().extract(ldef, i, loc()); }); + auto rextract = rvec_ == SCALAR ? + ExtractFn([&] (int i) { return rdef; }) : + ExtractFn([&] (int i) { return cg.world().extract(rdef, i, loc()); }); + + int n = (lvec_ != SCALAR ? lhs() : rhs())->type().as()->size(); + Array defs(n); + for (int i = 0; i < n; i++) + defs[i] = cg.world().binop(Token::to_binop(op), lextract(i), rextract(i), loc()); + return cg.world().definite_array(defs, loc()); +} + const Def* PostfixExpr::remit(CodeGen& cg) const { Value var = cg.lemit(lhs()); const Def* def = var.load(loc()); diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index f26ecd089..465ecabe3 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -913,49 +913,52 @@ inline Type matrix_elem_type(Type t) { } Type InfixExpr::check_arith_op(TypeSema& sema) const { - auto a = sema.check(lhs()); - auto b = sema.check(rhs()); + auto ltype = sema.check(lhs()); + auto rtype = sema.check(rhs()); + + auto lelem = matrix_elem_type(ltype); + auto relem = matrix_elem_type(rtype); + bool lscalar = ltype == lelem; + bool rscalar = rtype == relem; - auto elem_a = matrix_elem_type(a); - auto elem_b = matrix_elem_type(b); - bool scalar_a = a == elem_a; - bool scalar_b = b == elem_b; + lvec_ = lscalar ? SCALAR : (ltype.as()->is_vector() ? VECTOR : MATRIX); + rvec_ = rscalar ? SCALAR : (ltype.as()->is_vector() ? VECTOR : MATRIX); if (kind() == REM) { // For REM, the types must be integers - if (!sema.is_int(elem_a) || !sema.is_int(elem_b)) { + if (!sema.is_int(lelem) || !sema.is_int(relem)) { error(this) << "the modulus '%' operator is only valid on integers\n"; return sema.type_error(); } } else { - sema.expect_num(elem_a, lhs()); - sema.expect_num(elem_b, rhs()); + sema.expect_num(lelem, lhs()); + sema.expect_num(relem, rhs()); } // if operands are both scalars or both vectors, types must be equal - if ((scalar_a & scalar_b) && a == b) return a; - if ((scalar_a ^ scalar_b) && elem_a == elem_b) { return scalar_a ? b : a; } + if ((lscalar & rscalar) && ltype == rtype) return ltype; + if ((lscalar ^ rscalar) && lelem == relem) { return lscalar ? rtype : ltype; } - if (!scalar_a && !scalar_b && elem_a == elem_b) { - auto mat_a = a.as(); - auto mat_b = b.as(); + if (!lscalar && !rscalar && lelem == relem) { + auto lmat = ltype.as(); + auto rmat = rtype.as(); if (kind() == MUL) { // vector * matrix, matrix * vector, or matrix * matrix multiplication - if (!mat_a->is_vector()) { - if (mat_a->cols() == mat_b->rows()) return mat_b; + if (!lmat->is_vector()) { + if (lmat->cols() == rmat->rows()) return rmat; } else { - if (mat_a->rows() == mat_b->rows()) return mat_a; + if (lmat->rows() == rmat->rows()) return lmat; } - } else if (mat_a->rows() == mat_b->rows()) { + } else if (lmat->rows() == rmat->rows()) { // addition, subtraction, division, ... - return mat_a; + return lmat; } } - if (!a->is_error() && !b->is_error()) { + if (!ltype->is_error() && !rtype->is_error()) { error(this) << "types do not match for operator '" << Token(loc(), (Token::Kind)kind()) - << "', got " << a << " and " << b << " \n"; + << "', got " << ltype << " and " << rtype << " \n"; } return sema.type_error(); } From 46b4184f99152e569cfe84b410eff9d933b367c5 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Fri, 14 Oct 2016 14:57:38 +0200 Subject: [PATCH 11/17] Matrix multiplication --- src/impala/emit.cpp | 35 +++++++++++++++++++++++++++++++---- src/impala/parser.cpp | 28 +++++++++------------------- src/impala/tokenlist.h | 8 +++----- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/src/impala/emit.cpp b/src/impala/emit.cpp index 61e4dfc16..c009cbbfe 100644 --- a/src/impala/emit.cpp +++ b/src/impala/emit.cpp @@ -492,8 +492,35 @@ const Def* InfixExpr::emit(CodeGen& cg, TokenKind op, const Def* ldef, const Def if (lvec_ == SCALAR && rvec_ == SCALAR) return cg.world().binop(Token::to_binop(op), ldef, rdef, loc()); - typedef std::function ExtractFn; + if (kind() == MUL && lvec_ != SCALAR && rvec_ != SCALAR) { + auto lmat = lhs()->type().as(); + auto rmat = rhs()->type().as(); + if (!lmat->is_vector() || !rmat->is_vector()) { + // matrix-vector or vector-matrix multiplication + bool transpose = lmat->is_vector(); + int rows = transpose ? lmat->cols() : lmat->rows(); + int cols = rmat->cols(); + int lrows = lmat->rows(); + int rrows = rmat->rows(); + Array defs(rows * cols); + for (int i = 0; i < cols; i++) { + for (int j = 0; j < rows; j++) { + const Def* sum = nullptr; + for (int k = 0, n = rmat->rows(); k < n; k++) { + auto mul = cg.world().binop(ArithOp_mul, + cg.world().extract(ldef, transpose ? j : k * lrows + j, loc()), + cg.world().extract(rdef, i * rrows + k, loc()), loc()); + sum = sum ? cg.world().binop(ArithOp_add, sum, mul, loc()) : mul; + } + defs[transpose ? i : i * rows + j] = sum; + } + } + return cg.world().definite_array(defs, loc()); + } + } + // component-wise ops and scalar vs. matrix/vector ops are handled here + typedef std::function ExtractFn; auto lextract = lvec_ == SCALAR ? ExtractFn([&] (int i) { return ldef; }) : ExtractFn([&] (int i) { return cg.world().extract(ldef, i, loc()); }); @@ -603,9 +630,9 @@ const Def* MatrixExpr::remit(CodeGen& cg) const { case VEC2: case VEC3: case VEC4: - case MAT2: - case MAT3: - case MAT4: + case MAT2X2: + case MAT3X3: + case MAT4X4: case MAT2X3: case MAT2X4: case MAT3X2: diff --git a/src/impala/parser.cpp b/src/impala/parser.cpp index be28507de..58ad49d3f 100644 --- a/src/impala/parser.cpp +++ b/src/impala/parser.cpp @@ -73,9 +73,9 @@ case Token::VEC2: \ case Token::VEC3: \ case Token::VEC4: \ - case Token::MAT2: \ - case Token::MAT3: \ - case Token::MAT4: \ + case Token::MAT2X2: \ + case Token::MAT3X3: \ + case Token::MAT4X4: \ case Token::MAT2X3: \ case Token::MAT2X4: \ case Token::MAT3X2: \ @@ -727,9 +727,9 @@ const ASTType* Parser::parse_type() { case Token::VEC2: case Token::VEC3: case Token::VEC4: - case Token::MAT2: - case Token::MAT3: - case Token::MAT4: + case Token::MAT2X2: + case Token::MAT3X3: + case Token::MAT4X4: case Token::MAT2X3: case Token::MAT2X4: case Token::MAT3X2: @@ -886,22 +886,12 @@ const SimdASTType* Parser::parse_simd_type() { const MatrixASTType* Parser::parse_matrix_type() { auto mat = loc(new MatrixASTType()); - mat->cols_ = 1; switch (la()) { - case Token::VEC2: mat->rows_ = 2; mat->cols_ = 1; break; - case Token::VEC3: mat->rows_ = 3; mat->cols_ = 1; break; - case Token::VEC4: mat->rows_ = 4; mat->cols_ = 1; break; - case Token::MAT2: mat->rows_ = 2; mat->cols_ = 2; break; - case Token::MAT3: mat->rows_ = 2; mat->cols_ = 3; break; - case Token::MAT4: mat->rows_ = 2; mat->cols_ = 4; break; - case Token::MAT2X3: mat->rows_ = 2; mat->cols_ = 3; break; - case Token::MAT2X4: mat->rows_ = 2; mat->cols_ = 4; break; - case Token::MAT3X2: mat->rows_ = 3; mat->cols_ = 2; break; - case Token::MAT3X4: mat->rows_ = 3; mat->cols_ = 4; break; - case Token::MAT4X2: mat->rows_ = 4; mat->cols_ = 2; break; - case Token::MAT4X3: mat->rows_ = 4; mat->cols_ = 3; break; +#define IMPALA_MAT_KEY(tok, str, r, c) case Token:: tok : mat->rows_ = r; mat->cols_ = c; break; +#include "impala/tokenlist.h" default: THORIN_UNREACHABLE; } + assert(mat->rows_ > 0 && mat->cols_ > 0); lex(); expect(Token::L_BRACKET, "vector or matrix type"); mat->elem_type_ = parse_type(); diff --git a/src/impala/tokenlist.h b/src/impala/tokenlist.h index 132c1a5e5..d4c42e92b 100644 --- a/src/impala/tokenlist.h +++ b/src/impala/tokenlist.h @@ -111,9 +111,9 @@ IMPALA_KEY(SIMD, "simd") IMPALA_MAT_KEY(VEC2, "vec2", 2, 1) IMPALA_MAT_KEY(VEC3, "vec3", 3, 1) IMPALA_MAT_KEY(VEC4, "vec4", 4, 1) -IMPALA_MAT_KEY(MAT2, "mat2", 2, 2) -IMPALA_MAT_KEY(MAT3, "mat3", 3, 3) -IMPALA_MAT_KEY(MAT4, "mat4", 4, 4) +IMPALA_MAT_KEY(MAT2X2, "mat2x2", 2, 2) +IMPALA_MAT_KEY(MAT3X3, "mat3x3", 3, 3) +IMPALA_MAT_KEY(MAT4X4, "mat4x4", 4, 4) IMPALA_MAT_KEY(MAT2X3, "mat2x3", 2, 3) IMPALA_MAT_KEY(MAT2X4, "mat2x4", 2, 4) IMPALA_MAT_KEY(MAT3X2, "mat3x2", 3, 2) @@ -124,8 +124,6 @@ IMPALA_MAT_KEY(MAT4X3, "mat4x3", 4, 3) IMPALA_MAT_KEY(MAT_INVERSE, "inverse", 0, 0) IMPALA_MAT_KEY(VEC_DOT, "dot", 0, 0) IMPALA_MAT_KEY(VEC_CROSS, "cross", 0, 0) -IMPALA_MAT_KEY(VEC_LENGTH, "length", 0, 0) -IMPALA_MAT_KEY(VEC_NORMALIZE, "normalize", 0, 0) #undef IMPALA_MAT_KEY #undef IMPALA_KEY From c398ae4430ba93a2def6fb01bc8d6e2b2e5c26d0 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Fri, 14 Oct 2016 15:01:41 +0200 Subject: [PATCH 12/17] Bug fix in type checks --- src/impala/sema/typesema.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index 465ecabe3..74bcafc98 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -916,6 +916,8 @@ Type InfixExpr::check_arith_op(TypeSema& sema) const { auto ltype = sema.check(lhs()); auto rtype = sema.check(rhs()); + if (ltype->is_error() || rtype->is_error()) return sema.type_error(); + auto lelem = matrix_elem_type(ltype); auto relem = matrix_elem_type(rtype); bool lscalar = ltype == lelem; From 18e1fa7d78442e6ddd594679baa5af6ff4f089c9 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Fri, 14 Oct 2016 15:27:50 +0200 Subject: [PATCH 13/17] Implemented dot + cross --- src/impala/ast.h | 8 +++++ src/impala/emit.cpp | 34 ++++++++++++++++++++ src/impala/parser.cpp | 5 ++- src/impala/sema/typesema.cpp | 61 ++++++++++++++++++++++++++++++++---- 4 files changed, 101 insertions(+), 7 deletions(-) diff --git a/src/impala/ast.h b/src/impala/ast.h index e2ffa14d2..31fa7abdf 100644 --- a/src/impala/ast.h +++ b/src/impala/ast.h @@ -1385,6 +1385,14 @@ class MatrixExpr : public Expr, public Args { Type check_vector_args(TypeSema&, uint32_t) const; Type check_matrix_args(TypeSema&, uint32_t, uint32_t) const; + Type check_inverse(TypeSema&) const; + Type check_cross(TypeSema&) const; + Type check_dot(TypeSema&) const; + + const thorin::Def* emit_inverse(CodeGen&, const thorin::Def*) const; + const thorin::Def* emit_cross(CodeGen&, const thorin::Def*, const thorin::Def*) const; + const thorin::Def* emit_dot(CodeGen&, const thorin::Def*, const thorin::Def*) const; + friend class CodeGen; friend class Parser; friend class TypeSema; diff --git a/src/impala/emit.cpp b/src/impala/emit.cpp index c009cbbfe..8159c7191 100644 --- a/src/impala/emit.cpp +++ b/src/impala/emit.cpp @@ -667,11 +667,45 @@ const Def* MatrixExpr::remit(CodeGen& cg) const { } return cg.world().definite_array(defs, loc()); } + case MAT_INVERSE: return emit_inverse(cg, cg.remit(arg(0))); break; + case VEC_CROSS: return emit_cross(cg, cg.remit(arg(0)), cg.remit(arg(1))); break; + case VEC_DOT: return emit_dot(cg, cg.remit(arg(0)), cg.remit(arg(1))); break; default: break; } THORIN_UNREACHABLE; } +const Def* MatrixExpr::emit_inverse(CodeGen& cg, const Def* def) const { + THORIN_UNREACHABLE; +} + +const Def* MatrixExpr::emit_cross(CodeGen& cg, const Def* ldef, const Def* rdef) const { + Array defs(3); + for (int i = 0; i < 3; i++) { + int j = (i + 1) % 3, k = (i + 2) % 3; + defs[i] = cg.world().binop(ArithOp_sub, + cg.world().binop(ArithOp_mul, + cg.world().extract(ldef, j, loc()), + cg.world().extract(rdef, k, loc()), loc()), + cg.world().binop(ArithOp_mul, + cg.world().extract(ldef, k, loc()), + cg.world().extract(rdef, j, loc()), loc()), loc()); + } + return cg.world().definite_array(defs, loc()); +} + +const Def* MatrixExpr::emit_dot(CodeGen& cg, const Def* ldef, const Def* rdef) const { + int n = arg(0)->type().as()->rows(); + const Def* sum = nullptr; + for (int i = 0; i < n; i++) { + auto mul = cg.world().binop(ArithOp_mul, + cg.world().extract(ldef, i, loc()), + cg.world().extract(rdef, i, loc()), loc()); + sum = sum ? cg.world().binop(ArithOp_add, sum, mul, loc()) : mul; + } + return sum; +} + Value FieldExpr::lemit(CodeGen& cg) const { assert(!lhs()->type().isa()); return Value::create_agg(cg.lemit(lhs()), cg.world().literal_qu32(index(), loc())); diff --git a/src/impala/parser.cpp b/src/impala/parser.cpp index 58ad49d3f..589f77a4b 100644 --- a/src/impala/parser.cpp +++ b/src/impala/parser.cpp @@ -81,7 +81,10 @@ case Token::MAT3X2: \ case Token::MAT3X4: \ case Token::MAT4X2: \ - case Token::MAT4X3 + case Token::MAT4X3: \ + case Token::MAT_INVERSE: \ + case Token::VEC_CROSS: \ + case Token::VEC_DOT #define STMT_NOT_EXPR \ Token::LET: \ diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index 74bcafc98..05ea00d15 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -1316,7 +1316,8 @@ Type MatrixExpr::check(TypeSema& sema, TypeExpectation expected) const { error(this) << "arguments expected\n"; for (auto arg : args()) { - sema.check(arg); + // prevent error message explosion + if (sema.check(arg)->is_error()) return sema.type_error(); } // Check that the number and type of arguments is correct @@ -1333,11 +1334,9 @@ Type MatrixExpr::check(TypeSema& sema, TypeExpectation expected) const { } else { // Special functions (e.g. dot, cross, ...) have rows = cols = 0 switch (kind()) { - /*case Token::MAT_INVERSE: return check_vector_op(sema, 1); - case Token::VEC_DOT: return check_vector_op(sema, 2, true); - case Token::VEC_LENGTH: return check_vector_op(sema, 1, true); - case Token::VEC_NORMALIZE: return check_vector_op(sema, 1, false); - case Token::VEC_CROSS: return check_vector_op(sema, 2, false);*/ + case Token::MAT_INVERSE: return check_inverse(sema); + case Token::VEC_DOT: return check_dot(sema); + case Token::VEC_CROSS: return check_cross(sema); default: break; } } @@ -1437,6 +1436,56 @@ Type MatrixExpr::check_matrix_args(TypeSema& sema, uint32_t rows, uint32_t cols) return sema.matrix_type(type, rows, cols); } +Type MatrixExpr::check_inverse(TypeSema& sema) const { + if (num_args() != 1) { + error(this) << "incorrect number of arguments for the inverse function\n"; + return sema.type_error(); + } + + auto mat = arg(0)->type().isa(); + + if (!mat || mat->is_vector() || mat->rows() != mat->cols()) { + error(this) << "invalid operand type for the inverse function\n"; + return sema.type_error(); + } + + return mat; +} + +Type MatrixExpr::check_cross(TypeSema& sema) const { + if (num_args() != 2) { + error(this) << "incorrect number of arguments for the cross function\n"; + return sema.type_error(); + } + + auto lmat = arg(0)->type().isa(); + auto rmat = arg(1)->type().isa(); + + if (lmat != rmat || !lmat || !lmat->is_vector() || lmat->rows() != 3) { + error(this) << "invalid operand types for the cross function\n"; + return sema.type_error(); + } + + return sema.matrix_type(lmat->elem_type(), 3); +} + +Type MatrixExpr::check_dot(TypeSema& sema) const { + if (num_args() != 2) { + error(this) << "incorrect number of arguments for the dot function\n"; + return sema.type_error(); + } + + auto lmat = arg(0)->type().isa(); + auto rmat = arg(1)->type().isa(); + + if (lmat != rmat || !lmat || !lmat->is_vector()) { + error(this) << "invalid operand types for the dot function\n"; + return sema.type_error(); + } + + return lmat->elem_type(); +} + Type BlockExprBase::check(TypeSema& sema, TypeExpectation expected) const { THORIN_PUSH(sema.cur_block_, this); for (auto stmt : stmts()) From 3c2a9f4f0bd53997ec2ef12a517f6764491d1d19 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Fri, 14 Oct 2016 16:24:41 +0200 Subject: [PATCH 14/17] Fix comments --- src/impala/sema/typesema.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index 05ea00d15..c598f41d7 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -1320,7 +1320,6 @@ Type MatrixExpr::check(TypeSema& sema, TypeExpectation expected) const { if (sema.check(arg)->is_error()) return sema.type_error(); } - // Check that the number and type of arguments is correct uint32_t rows, cols; switch (kind()) { #define IMPALA_MAT_KEY(tok, str, r, c) case Token:: tok : rows = r; cols = c; break; @@ -1332,7 +1331,7 @@ Type MatrixExpr::check(TypeSema& sema, TypeExpectation expected) const { if (cols > 1) { return check_matrix_args(sema, rows, cols); } else { return check_vector_args(sema, rows); } } else { - // Special functions (e.g. dot, cross, ...) have rows = cols = 0 + // special functions (e.g. dot, cross, ...) have rows = cols = 0 switch (kind()) { case Token::MAT_INVERSE: return check_inverse(sema); case Token::VEC_DOT: return check_dot(sema); From f0ee2993f17fc99030b758102bb9b0474c760d3a Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Tue, 18 Oct 2016 14:23:58 +0200 Subject: [PATCH 15/17] Matrix inverse + determinant --- src/impala/ast.h | 2 + src/impala/emit.cpp | 87 +++++++++++++++++++++++++++++++----- src/impala/parser.cpp | 1 + src/impala/sema/typesema.cpp | 25 +++++++++-- src/impala/tokenlist.h | 7 +-- 5 files changed, 104 insertions(+), 18 deletions(-) diff --git a/src/impala/ast.h b/src/impala/ast.h index 31fa7abdf..279b53c69 100644 --- a/src/impala/ast.h +++ b/src/impala/ast.h @@ -1386,10 +1386,12 @@ class MatrixExpr : public Expr, public Args { Type check_matrix_args(TypeSema&, uint32_t, uint32_t) const; Type check_inverse(TypeSema&) const; + Type check_determinant(TypeSema&) const; Type check_cross(TypeSema&) const; Type check_dot(TypeSema&) const; const thorin::Def* emit_inverse(CodeGen&, const thorin::Def*) const; + const thorin::Def* emit_determinant(CodeGen&, const thorin::Def*) const; const thorin::Def* emit_cross(CodeGen&, const thorin::Def*, const thorin::Def*) const; const thorin::Def* emit_dot(CodeGen&, const thorin::Def*, const thorin::Def*) const; diff --git a/src/impala/emit.cpp b/src/impala/emit.cpp index 8159c7191..e887929af 100644 --- a/src/impala/emit.cpp +++ b/src/impala/emit.cpp @@ -507,10 +507,10 @@ const Def* InfixExpr::emit(CodeGen& cg, TokenKind op, const Def* ldef, const Def for (int j = 0; j < rows; j++) { const Def* sum = nullptr; for (int k = 0, n = rmat->rows(); k < n; k++) { - auto mul = cg.world().binop(ArithOp_mul, + auto mul = cg.world().arithop_mul( cg.world().extract(ldef, transpose ? j : k * lrows + j, loc()), cg.world().extract(rdef, i * rrows + k, loc()), loc()); - sum = sum ? cg.world().binop(ArithOp_add, sum, mul, loc()) : mul; + sum = sum ? cg.world().arithop_add(sum, mul, loc()) : mul; } defs[transpose ? i : i * rows + j] = sum; } @@ -667,27 +667,92 @@ const Def* MatrixExpr::remit(CodeGen& cg) const { } return cg.world().definite_array(defs, loc()); } - case MAT_INVERSE: return emit_inverse(cg, cg.remit(arg(0))); break; - case VEC_CROSS: return emit_cross(cg, cg.remit(arg(0)), cg.remit(arg(1))); break; - case VEC_DOT: return emit_dot(cg, cg.remit(arg(0)), cg.remit(arg(1))); break; + case MAT_INVERSE: return emit_inverse(cg, cg.remit(arg(0))); break; + case MAT_DETERMINANT: return emit_determinant(cg, cg.remit(arg(0))); break; + case VEC_CROSS: return emit_cross(cg, cg.remit(arg(0)), cg.remit(arg(1))); break; + case VEC_DOT: return emit_dot(cg, cg.remit(arg(0)), cg.remit(arg(1))); break; default: break; } THORIN_UNREACHABLE; } +static const Def* submatrix(CodeGen& cg, int n, int row, int col, const Def* def, const Location& loc) { + // removes row and col from the given square matrix + Array defs((n - 1) * (n - 1)); + for (int i = 0, p = 0; i < n; i++) { + if (i == col) continue; + for (int j = 0, q = 0; j < n; j++) { + if (j == row) continue; + defs[p * (n - 1) + q] = cg.world().extract(def, i * n + j, loc); + q++; + } + p++; + } + return cg.world().definite_array(defs, loc); +} + +static const Def* determinant(CodeGen& cg, int n, const Def* def, const Location& loc) { + if (n == 1) return cg.world().extract(def, 0, loc); + if (n == 2) { + return cg.world().arithop_sub( + cg.world().arithop_mul(cg.world().extract(def, 0, loc), + cg.world().extract(def, 3, loc), loc), + cg.world().arithop_mul(cg.world().extract(def, 1, loc), + cg.world().extract(def, 2, loc), loc), loc); + } + + const Def* sum = nullptr; + for (int i = 0; i < n; i++) { + auto mul = cg.world().arithop_mul( + cg.world().extract(def, i, loc), + determinant(cg, n - 1, submatrix(cg, n, i, 0, def, loc), loc), loc); + sum = sum ? cg.world().binop(i % 2 ? ArithOp_add : ArithOp_sub, sum, mul, loc) : mul; + } + return sum; +} + +const Def* MatrixExpr::emit_determinant(CodeGen& cg, const Def* def) const { + return determinant(cg, arg(0)->type().as()->rows(), def, loc()); +} + const Def* MatrixExpr::emit_inverse(CodeGen& cg, const Def* def) const { - THORIN_UNREACHABLE; + // the formula used is A.t(com A) = (det A).Id + auto mat = arg(0)->type().as(); + int n = mat->rows(); + + // 1. compute (det A) + auto det = determinant(cg, n, def, loc()); + + // 2. compute t(com A) + Array elems(n * n); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + auto subdet = determinant(cg, n - 1, submatrix(cg, n, i, j, def, loc()), loc()); + elems[i * n + j] = (i + j) % 2 ? cg.world().arithop_minus(subdet, loc()) : subdet; + } + } + + // 3. compute det A == 0 ? 0 : 1/detA + auto elem = cg.convert(mat->elem_type()); + auto zero = cg.world().zero(elem, loc()); + auto one = cg.world().literal(pf32(1.0f), loc()); + auto cmp = cg.world().cmp_eq(det, zero, loc()); + auto inv = cg.world().select(cmp, zero, cg.world().arithop_div(one, det, loc()), loc()); + + for (int i = 0; i < n * n; i++) elems[i] = cg.world().arithop_mul(elems[i], inv, loc()); + + return cg.world().definite_array(elems, loc()); } const Def* MatrixExpr::emit_cross(CodeGen& cg, const Def* ldef, const Def* rdef) const { Array defs(3); for (int i = 0; i < 3; i++) { int j = (i + 1) % 3, k = (i + 2) % 3; - defs[i] = cg.world().binop(ArithOp_sub, - cg.world().binop(ArithOp_mul, + defs[i] = cg.world().arithop_sub( + cg.world().arithop_mul( cg.world().extract(ldef, j, loc()), cg.world().extract(rdef, k, loc()), loc()), - cg.world().binop(ArithOp_mul, + cg.world().arithop_mul( cg.world().extract(ldef, k, loc()), cg.world().extract(rdef, j, loc()), loc()), loc()); } @@ -698,10 +763,10 @@ const Def* MatrixExpr::emit_dot(CodeGen& cg, const Def* ldef, const Def* rdef) c int n = arg(0)->type().as()->rows(); const Def* sum = nullptr; for (int i = 0; i < n; i++) { - auto mul = cg.world().binop(ArithOp_mul, + auto mul = cg.world().arithop_mul( cg.world().extract(ldef, i, loc()), cg.world().extract(rdef, i, loc()), loc()); - sum = sum ? cg.world().binop(ArithOp_add, sum, mul, loc()) : mul; + sum = sum ? cg.world().arithop_add(sum, mul, loc()) : mul; } return sum; } diff --git a/src/impala/parser.cpp b/src/impala/parser.cpp index 589f77a4b..cf4c82fbe 100644 --- a/src/impala/parser.cpp +++ b/src/impala/parser.cpp @@ -83,6 +83,7 @@ case Token::MAT4X2: \ case Token::MAT4X3: \ case Token::MAT_INVERSE: \ + case Token::MAT_DETERMINANT: \ case Token::VEC_CROSS: \ case Token::VEC_DOT diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index c598f41d7..127d18f8c 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -1333,9 +1333,10 @@ Type MatrixExpr::check(TypeSema& sema, TypeExpectation expected) const { } else { // special functions (e.g. dot, cross, ...) have rows = cols = 0 switch (kind()) { - case Token::MAT_INVERSE: return check_inverse(sema); - case Token::VEC_DOT: return check_dot(sema); - case Token::VEC_CROSS: return check_cross(sema); + case Token::MAT_INVERSE: return check_inverse(sema); + case Token::MAT_DETERMINANT: return check_determinant(sema); + case Token::VEC_DOT: return check_dot(sema); + case Token::VEC_CROSS: return check_cross(sema); default: break; } } @@ -1443,7 +1444,7 @@ Type MatrixExpr::check_inverse(TypeSema& sema) const { auto mat = arg(0)->type().isa(); - if (!mat || mat->is_vector() || mat->rows() != mat->cols()) { + if (!mat || mat->is_vector() || mat->rows() != mat->cols() || !sema.is_float(mat->elem_type())) { error(this) << "invalid operand type for the inverse function\n"; return sema.type_error(); } @@ -1451,6 +1452,22 @@ Type MatrixExpr::check_inverse(TypeSema& sema) const { return mat; } +Type MatrixExpr::check_determinant(TypeSema& sema) const { + if (num_args() != 1) { + error(this) << "incorrect number of arguments for the determinant function\n"; + return sema.type_error(); + } + + auto mat = arg(0)->type().isa(); + + if (!mat || mat->is_vector() || mat->rows() != mat->cols()) { + error(this) << "invalid operand type for the determinant function\n"; + return sema.type_error(); + } + + return mat->elem_type(); +} + Type MatrixExpr::check_cross(TypeSema& sema) const { if (num_args() != 2) { error(this) << "incorrect number of arguments for the cross function\n"; diff --git a/src/impala/tokenlist.h b/src/impala/tokenlist.h index d4c42e92b..ec303496e 100644 --- a/src/impala/tokenlist.h +++ b/src/impala/tokenlist.h @@ -121,9 +121,10 @@ IMPALA_MAT_KEY(MAT3X4, "mat3x4", 3, 4) IMPALA_MAT_KEY(MAT4X2, "mat4x2", 4, 2) IMPALA_MAT_KEY(MAT4X3, "mat4x3", 4, 3) -IMPALA_MAT_KEY(MAT_INVERSE, "inverse", 0, 0) -IMPALA_MAT_KEY(VEC_DOT, "dot", 0, 0) -IMPALA_MAT_KEY(VEC_CROSS, "cross", 0, 0) +IMPALA_MAT_KEY(MAT_INVERSE, "inverse", 0, 0) +IMPALA_MAT_KEY(MAT_DETERMINANT, "determinant", 0, 0) +IMPALA_MAT_KEY(VEC_DOT, "dot", 0, 0) +IMPALA_MAT_KEY(VEC_CROSS, "cross", 0, 0) #undef IMPALA_MAT_KEY #undef IMPALA_KEY From 8f5ec9791f71818703283c1ce2d6e904ace01be6 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Tue, 18 Oct 2016 18:49:33 +0200 Subject: [PATCH 16/17] Added support for CGen --- src/impala/cgen.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/impala/cgen.cpp b/src/impala/cgen.cpp index 560c5538f..7bd6cc80b 100644 --- a/src/impala/cgen.cpp +++ b/src/impala/cgen.cpp @@ -163,6 +163,13 @@ class CGen { return true; } + if (auto matrix_type = type.isa()) { + if (!ctype_from_impala(matrix_type->elem_type(), ctype_prefix, ctype_suffix)) + return false; + ctype_suffix = "[" + std::to_string(matrix_type->size()) + "]"; + return true; + } + return false; } From 2c423dcb82657554b5d01ced7f57e7268abb3db5 Mon Sep 17 00:00:00 2001 From: Arsene Perard-Gayot Date: Wed, 19 Oct 2016 14:47:54 +0200 Subject: [PATCH 17/17] Bug fix in vector ops --- src/impala/ast.cpp | 6 +++- src/impala/cgen.cpp | 32 ++++++++++--------- src/impala/emit.cpp | 1 - src/impala/sema/typesema.cpp | 62 +++++++++++++++++++----------------- 4 files changed, 55 insertions(+), 46 deletions(-) diff --git a/src/impala/ast.cpp b/src/impala/ast.cpp index ea6544a09..060cc6719 100644 --- a/src/impala/ast.cpp +++ b/src/impala/ast.cpp @@ -86,7 +86,11 @@ bool MapExpr::is_lvalue() const { } bool PrefixExpr::is_lvalue() const { return kind() == MUL; } -bool FieldExpr::is_lvalue() const { return (lhs()->is_lvalue() && !lhs()->type().isa()) || lhs()->type().isa(); } +bool FieldExpr::is_lvalue() const { + if (lhs()->type().isa()) + return (strlen(symbol().str()) == 1) && lhs()->is_lvalue(); + return lhs()->is_lvalue() || lhs()->type().isa(); +} bool CastExpr::is_lvalue() const { return lhs()->is_lvalue(); } //------------------------------------------------------------------------------ diff --git a/src/impala/cgen.cpp b/src/impala/cgen.cpp index 7bd6cc80b..d9232ecc2 100644 --- a/src/impala/cgen.cpp +++ b/src/impala/cgen.cpp @@ -40,7 +40,7 @@ class CGen { } static std::ostream& cgen_error(const ASTNode* node) { - return error(node) << "cannot generate C interface : "; + return error(node) << "cannot generate C interface: "; } // Generates a C type from an Impala type @@ -48,28 +48,28 @@ class CGen { if (auto prim_type = type.isa()) { switch (prim_type->primtype_kind()) { case PrimType_i8: - ctype_prefix = "char"; ctype_suffix = ""; + ctype_prefix = "int8_t"; ctype_suffix = ""; return true; case PrimType_i16: - ctype_prefix = "short"; ctype_suffix = ""; + ctype_prefix = "int16_t"; ctype_suffix = ""; return true; case PrimType_i32: - ctype_prefix = "int"; ctype_suffix = ""; + ctype_prefix = "int32_t"; ctype_suffix = ""; return true; case PrimType_i64: - ctype_prefix = "long long"; ctype_suffix = ""; + ctype_prefix = "int64_t"; ctype_suffix = ""; return true; case PrimType_u8: - ctype_prefix = "unsigned char"; ctype_suffix = ""; + ctype_prefix = "uint8_t"; ctype_suffix = ""; return true; case PrimType_u16: - ctype_prefix = "unsigned short"; ctype_suffix = ""; + ctype_prefix = "uint16_t"; ctype_suffix = ""; return true; case PrimType_u32: - ctype_prefix = "unsigned int"; ctype_suffix = ""; + ctype_prefix = "uint32_t"; ctype_suffix = ""; return true; case PrimType_u64: - ctype_prefix = "unsigned long long"; ctype_suffix = ""; + ctype_prefix = "uint64_t"; ctype_suffix = ""; return true; case PrimType_f16: ctype_prefix = "half"; ctype_suffix = ""; @@ -136,14 +136,15 @@ class CGen { // &[T * N] -> T* // &T -> T* - if (auto array_type = ptr_type->referenced_type().isa()) { - if (!ctype_from_impala(array_type->elem_type(), ctype_prefix, ctype_suffix)) - return false; - } else { - if (!ctype_from_impala(ptr_type->referenced_type(), ctype_prefix, ctype_suffix)) - return false; + auto ref_type = ptr_type->referenced_type(); + if (auto array_type = ref_type.isa()) { + ref_type.clear(); + ref_type = array_type->elem_type(); } + if (!ctype_from_impala(ref_type, ctype_prefix, ctype_suffix)) + return false; + ctype_prefix += "*"; ctype_suffix = ""; return true; @@ -354,6 +355,7 @@ bool generate_c_interface(const ModContents* mod, const CGenOptions& opts, std:: o << "/* " << opts.file_name << " : Impala interface file generated by impala */\n" << "#ifndef " << opts.guard << "\n" << "#define " << opts.guard << "\n\n" + << "#include \n\n" << "#ifdef __cplusplus\n" << "extern \"C\" {\n" << "#endif\n" diff --git a/src/impala/emit.cpp b/src/impala/emit.cpp index e887929af..b3bc0bde4 100644 --- a/src/impala/emit.cpp +++ b/src/impala/emit.cpp @@ -772,7 +772,6 @@ const Def* MatrixExpr::emit_dot(CodeGen& cg, const Def* ldef, const Def* rdef) c } Value FieldExpr::lemit(CodeGen& cg) const { - assert(!lhs()->type().isa()); return Value::create_agg(cg.lemit(lhs()), cg.world().literal_qu32(index(), loc())); } diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index 127d18f8c..f1e7d3695 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -60,7 +60,7 @@ class TypeSema : public TypeTable { } return true; } - Type scalar_type(const Expr*); + Type scalar_type(Type); bool is_int(Type t); bool is_float(Type t); bool expect_int(const Expr*); @@ -136,8 +136,7 @@ class TypeSema : public TypeTable { //------------------------------------------------------------------------------ -Type TypeSema::scalar_type(const Expr* e) { - Type t = e->type(); +Type TypeSema::scalar_type(Type t) { if (auto simd = t.isa()) { return simd->elem_type(); } @@ -154,8 +153,7 @@ bool TypeSema::is_float(Type t) { } bool TypeSema::expect_int(const Expr* expr) { - auto t = scalar_type(expr); - + auto t = scalar_type(expr->type()); if (!t->is_error() && !is_int(t)) { error(expr) << "expected integer type but found " << t << "\n"; return false; @@ -164,8 +162,7 @@ bool TypeSema::expect_int(const Expr* expr) { } bool TypeSema::expect_int_or_bool(const Expr* expr) { - auto t = scalar_type(expr); - + auto t = scalar_type(expr->type()); if (!t->is_error() && !t->is_bool() && !is_int(t)) { error(expr) << "expected integer or boolean type but found " << t << "\n"; return false; @@ -174,7 +171,8 @@ bool TypeSema::expect_int_or_bool(const Expr* expr) { } bool TypeSema::expect_num(Type t, const Expr* expr) { - if (!t->is_error() && !t->is_bool() && !is_int(t) && !is_float(t)) { + auto scalar = scalar_type(t); + if (!scalar->is_error() && !scalar->is_bool() && !is_int(scalar) && !is_float(scalar)) { error(expr) << "expected number type but found " << t << "\n"; return false; } @@ -182,7 +180,7 @@ bool TypeSema::expect_num(Type t, const Expr* expr) { } void TypeSema::expect_num(const Expr* expr) { - expect_num(scalar_type(expr), expr); + expect_num(expr->type(), expr); } Type TypeSema::expect_type(const Expr* expr, Type found_type, TypeExpectation expected) { @@ -373,17 +371,15 @@ Type SimdASTType::check(TypeSema& sema) const { Type MatrixASTType::check(TypeSema& sema) const { auto type = sema.check(elem_type()); - if (type.isa() || type.isa()) { - if (!sema.is_int(type) && !sema.is_float(type)) { - error(this) << "only floating point and integer types are supported for " - << (is_vector() ? "vectors" : "matrices") << "\n"; - return sema.type_error(); - } - return sema.matrix_type(type, rows(), cols()); + auto scalar = sema.scalar_type(type); + + if (!sema.is_int(scalar) && !sema.is_float(scalar)) { + error(this) << "only floating point and integer types are supported for " + << (is_vector() ? "vectors" : "matrices") << "\n"; + return sema.type_error(); } - error(this) << "vector or matrix types only support primitive or simd types\n"; - return sema.type_error(); + return sema.matrix_type(type, rows(), cols()); } //------------------------------------------------------------------------------ @@ -928,7 +924,7 @@ Type InfixExpr::check_arith_op(TypeSema& sema) const { if (kind() == REM) { // For REM, the types must be integers - if (!sema.is_int(lelem) || !sema.is_int(relem)) { + if (!sema.is_int(sema.scalar_type(lelem)) || !sema.is_int(sema.scalar_type(relem))) { error(this) << "the modulus '%' operator is only valid on integers\n"; return sema.type_error(); } @@ -1192,6 +1188,9 @@ Type FieldExpr::check_as_struct(TypeSema& sema, Type expected) const { ltype = sema.check(lhs()); } + if (auto mat = ltype.isa()) + return check_as_matrix(sema, expected); + if (auto struct_app = ltype.isa()) { if (auto field_decl = struct_app->struct_abs_type()->struct_decl()->field_decl(symbol())) { index_ = field_decl->index(); @@ -1236,6 +1235,8 @@ Type FieldExpr::check_as_matrix(TypeSema& sema, Type expected) const { return sema.type_error(); } + if (len_dst == 1) index_ = len_src; + return len_dst > 1 ? static_cast(sema.matrix_type(ltype->elem_type(), len_dst)) : ltype->elem_type(); } @@ -1354,7 +1355,8 @@ Type MatrixExpr::check_vector_args(TypeSema& sema, uint32_t rows) const { else elem_type = type; - if (!sema.is_int(elem_type) && !sema.is_float(elem_type)) { + auto scalar = sema.scalar_type(elem_type); + if (!sema.is_int(scalar) && !sema.is_float(scalar)) { error(this) << "only floating point and integer types are supported for vectors\n"; return sema.type_error(); } @@ -1418,18 +1420,14 @@ Type MatrixExpr::check_matrix_args(TypeSema& sema, uint32_t rows, uint32_t cols) return sema.matrix_type(mat->elem_type(), rows, cols); } - if (!type.isa() && !type.isa()) { - error(this) << "incorrect type for matrix initializer\n"; - return sema.type_error(); - } - - if (!sema.is_int(type) && !sema.is_float(type)) { + auto scalar = sema.scalar_type(type); + if (!sema.is_int(scalar) && !sema.is_float(scalar)) { error(this) << "only floating point and integer types are supported for matrices\n"; return sema.type_error(); } if (num_args() > 1 && num_args() != rows * cols) { - error(this) << "incorrect number of initializers for vector: got " + error(this) << "incorrect number of initializers for matrix: got " << num_args() << ", expected " << rows * cols << "\n"; } @@ -1444,11 +1442,17 @@ Type MatrixExpr::check_inverse(TypeSema& sema) const { auto mat = arg(0)->type().isa(); - if (!mat || mat->is_vector() || mat->rows() != mat->cols() || !sema.is_float(mat->elem_type())) { - error(this) << "invalid operand type for the inverse function\n"; + if (!mat || mat->is_vector() || mat->rows() != mat->cols()) { + error(this) << "invalid operand type for the inverse function: got " + << arg(0)->type() << ", expected square matrix\n"; return sema.type_error(); } + if (!sema.is_float(sema.scalar_type(mat->elem_type()))) { + error(this) << "the inverse operation is only available for floating point matrices\n"; + return sema.type_error(); + } + return mat; }