diff --git a/src/impala/ast.cpp b/src/impala/ast.cpp index c29723639..060cc6719 100644 --- a/src/impala/ast.cpp +++ b/src/impala/ast.cpp @@ -86,7 +86,11 @@ bool MapExpr::is_lvalue() const { } bool PrefixExpr::is_lvalue() const { return kind() == MUL; } -bool FieldExpr::is_lvalue() const { return lhs()->is_lvalue() || lhs()->type().isa(); } +bool FieldExpr::is_lvalue() const { + if (lhs()->type().isa()) + return (strlen(symbol().str()) == 1) && lhs()->is_lvalue(); + return lhs()->is_lvalue() || lhs()->type().isa(); +} bool CastExpr::is_lvalue() const { return lhs()->is_lvalue(); } //------------------------------------------------------------------------------ diff --git a/src/impala/ast.h b/src/impala/ast.h index b37b4c75e..279b53c69 100644 --- a/src/impala/ast.h +++ b/src/impala/ast.h @@ -422,6 +422,25 @@ class SimdASTType : public ArrayASTType { friend class Parser; }; +class MatrixASTType : public ArrayASTType { +public: + uint32_t rows() const { return rows_; } + uint32_t cols() const { return cols_; } + + bool is_vector() const { return cols_ == 1; } + + virtual std::ostream& stream(std::ostream&) const override; + virtual void check(NameSema&) const override; + virtual void check(BorrowSema&) const override; + +private: + virtual Type check(TypeSema&) const override; + + thorin::u32 rows_, cols_; + + friend class Parser; +}; + //------------------------------------------------------------------------------ /* @@ -1097,10 +1116,19 @@ class InfixExpr : public Expr { virtual std::ostream& stream(std::ostream&) const override; virtual Type check(TypeSema&, TypeExpectation) const override; + Type check_arith_op(TypeSema&) const; + const thorin::Def* emit(CodeGen&, TokenKind, const thorin::Def*, const thorin::Def*) const; + Kind kind_; AutoPtr lhs_; AutoPtr rhs_; + mutable enum VecOp { + SCALAR, + VECTOR, + MATRIX + } lvec_, rvec_; + friend class Parser; }; @@ -1138,11 +1166,11 @@ class FieldExpr : public Expr { const Identifier* identifier() const { return identifier_; } Symbol symbol() const { return identifier()->symbol(); } uint32_t index() const { return index_; } + const std::vector& swizzle() const { return swizzle_; } virtual bool is_lvalue() const override; virtual void take_address() const override; virtual void check(NameSema&) const override; virtual void check(BorrowSema&) const override; - Type check_as_struct(TypeSema&, Type) const; private: virtual std::ostream& stream(std::ostream&) const override; @@ -1150,9 +1178,13 @@ class FieldExpr : public Expr { virtual thorin::Value lemit(CodeGen&) const override; virtual const thorin::Def* remit(CodeGen&) const override; + Type check_as_struct(TypeSema&, Type) const; + Type check_as_matrix(TypeSema&, Type) const; + AutoPtr lhs_; AutoPtr identifier_; mutable uint32_t index_ = uint32_t(-1); + mutable std::vector swizzle_; friend class Parser; friend class MapExpr; // remove this @@ -1334,6 +1366,42 @@ class MapExpr : public Expr, public Args, public TypeArgs { friend class TypeSema; }; +class MatrixExpr : public Expr, public Args { +public: + enum Kind { +#define IMPALA_MAT_KEY(tok, str, r, c) tok = Token:: tok, +#include "tokenlist.h" + }; + + Kind kind() const { return kind_; } + virtual void check(NameSema&) const override; + virtual void check(BorrowSema&) const override; + +private: + virtual std::ostream& stream(std::ostream&) const override; + virtual Type check(TypeSema&, TypeExpectation) const override; + virtual const thorin::Def* remit(CodeGen&) const override; + + Type check_vector_args(TypeSema&, uint32_t) const; + Type check_matrix_args(TypeSema&, uint32_t, uint32_t) const; + + Type check_inverse(TypeSema&) const; + Type check_determinant(TypeSema&) const; + Type check_cross(TypeSema&) const; + Type check_dot(TypeSema&) const; + + const thorin::Def* emit_inverse(CodeGen&, const thorin::Def*) const; + const thorin::Def* emit_determinant(CodeGen&, const thorin::Def*) const; + const thorin::Def* emit_cross(CodeGen&, const thorin::Def*, const thorin::Def*) const; + const thorin::Def* emit_dot(CodeGen&, const thorin::Def*, const thorin::Def*) const; + + friend class CodeGen; + friend class Parser; + friend class TypeSema; + + Kind kind_; +}; + class StmtLikeExpr : public Expr {}; class BlockExprBase : public StmtLikeExpr { diff --git a/src/impala/cgen.cpp b/src/impala/cgen.cpp index 560c5538f..d9232ecc2 100644 --- a/src/impala/cgen.cpp +++ b/src/impala/cgen.cpp @@ -40,7 +40,7 @@ class CGen { } static std::ostream& cgen_error(const ASTNode* node) { - return error(node) << "cannot generate C interface : "; + return error(node) << "cannot generate C interface: "; } // Generates a C type from an Impala type @@ -48,28 +48,28 @@ class CGen { if (auto prim_type = type.isa()) { switch (prim_type->primtype_kind()) { case PrimType_i8: - ctype_prefix = "char"; ctype_suffix = ""; + ctype_prefix = "int8_t"; ctype_suffix = ""; return true; case PrimType_i16: - ctype_prefix = "short"; ctype_suffix = ""; + ctype_prefix = "int16_t"; ctype_suffix = ""; return true; case PrimType_i32: - ctype_prefix = "int"; ctype_suffix = ""; + ctype_prefix = "int32_t"; ctype_suffix = ""; return true; case PrimType_i64: - ctype_prefix = "long long"; ctype_suffix = ""; + ctype_prefix = "int64_t"; ctype_suffix = ""; return true; case PrimType_u8: - ctype_prefix = "unsigned char"; ctype_suffix = ""; + ctype_prefix = "uint8_t"; ctype_suffix = ""; return true; case PrimType_u16: - ctype_prefix = "unsigned short"; ctype_suffix = ""; + ctype_prefix = "uint16_t"; ctype_suffix = ""; return true; case PrimType_u32: - ctype_prefix = "unsigned int"; ctype_suffix = ""; + ctype_prefix = "uint32_t"; ctype_suffix = ""; return true; case PrimType_u64: - ctype_prefix = "unsigned long long"; ctype_suffix = ""; + ctype_prefix = "uint64_t"; ctype_suffix = ""; return true; case PrimType_f16: ctype_prefix = "half"; ctype_suffix = ""; @@ -136,14 +136,15 @@ class CGen { // &[T * N] -> T* // &T -> T* - if (auto array_type = ptr_type->referenced_type().isa()) { - if (!ctype_from_impala(array_type->elem_type(), ctype_prefix, ctype_suffix)) - return false; - } else { - if (!ctype_from_impala(ptr_type->referenced_type(), ctype_prefix, ctype_suffix)) - return false; + auto ref_type = ptr_type->referenced_type(); + if (auto array_type = ref_type.isa()) { + ref_type.clear(); + ref_type = array_type->elem_type(); } + if (!ctype_from_impala(ref_type, ctype_prefix, ctype_suffix)) + return false; + ctype_prefix += "*"; ctype_suffix = ""; return true; @@ -163,6 +164,13 @@ class CGen { return true; } + if (auto matrix_type = type.isa()) { + if (!ctype_from_impala(matrix_type->elem_type(), ctype_prefix, ctype_suffix)) + return false; + ctype_suffix = "[" + std::to_string(matrix_type->size()) + "]"; + return true; + } + return false; } @@ -347,6 +355,7 @@ bool generate_c_interface(const ModContents* mod, const CGenOptions& opts, std:: o << "/* " << opts.file_name << " : Impala interface file generated by impala */\n" << "#ifndef " << opts.guard << "\n" << "#define " << opts.guard << "\n\n" + << "#include \n\n" << "#ifdef __cplusplus\n" << "extern \"C\" {\n" << "#endif\n" diff --git a/src/impala/emit.cpp b/src/impala/emit.cpp index 9879a1ed4..b3bc0bde4 100644 --- a/src/impala/emit.cpp +++ b/src/impala/emit.cpp @@ -182,6 +182,10 @@ const thorin::Type* SimdTypeNode::convert(CodeGen& cg) const { return cg.world().type(scalar->as()->primtype_kind(), size()); } +const thorin::Type* MatrixTypeNode::convert(CodeGen& cg) const { + return cg.world().definite_array_type(cg.convert(elem_type()), size()); +} + /* * Decls and Function */ @@ -466,16 +470,13 @@ const Def* InfixExpr::remit(CodeGen& cg) const { return cg.converge(this, x); } default: - const TokenKind op = (TokenKind) kind(); - + auto op = (TokenKind)kind(); if (Token::is_assign(op)) { Value lvar = cg.lemit(lhs()); const Def* rdef = cg.remit(rhs()); - if (op != Token::ASGN) { - TokenKind sop = Token::separate_assign(op); - rdef = cg.world().binop(Token::to_binop(sop), lvar.load(loc()), rdef, loc()); - } + if (op != Token::ASGN) + rdef = emit(cg, Token::separate_assign(op), lvar.load(loc()), rdef); lvar.store(rdef, loc()); return cg.world().tuple({}, loc()); @@ -483,8 +484,55 @@ const Def* InfixExpr::remit(CodeGen& cg) const { const Def* ldef = cg.remit(lhs()); const Def* rdef = cg.remit(rhs()); - return cg.world().binop(Token::to_binop(op), ldef, rdef, loc()); + return emit(cg, op, ldef, rdef); + } +} + +const Def* InfixExpr::emit(CodeGen& cg, TokenKind op, const Def* ldef, const Def* rdef) const { + if (lvec_ == SCALAR && rvec_ == SCALAR) + return cg.world().binop(Token::to_binop(op), ldef, rdef, loc()); + + if (kind() == MUL && lvec_ != SCALAR && rvec_ != SCALAR) { + auto lmat = lhs()->type().as(); + auto rmat = rhs()->type().as(); + if (!lmat->is_vector() || !rmat->is_vector()) { + // matrix-vector or vector-matrix multiplication + bool transpose = lmat->is_vector(); + int rows = transpose ? lmat->cols() : lmat->rows(); + int cols = rmat->cols(); + int lrows = lmat->rows(); + int rrows = rmat->rows(); + Array defs(rows * cols); + for (int i = 0; i < cols; i++) { + for (int j = 0; j < rows; j++) { + const Def* sum = nullptr; + for (int k = 0, n = rmat->rows(); k < n; k++) { + auto mul = cg.world().arithop_mul( + cg.world().extract(ldef, transpose ? j : k * lrows + j, loc()), + cg.world().extract(rdef, i * rrows + k, loc()), loc()); + sum = sum ? cg.world().arithop_add(sum, mul, loc()) : mul; + } + defs[transpose ? i : i * rows + j] = sum; + } + } + return cg.world().definite_array(defs, loc()); + } } + + // component-wise ops and scalar vs. matrix/vector ops are handled here + typedef std::function ExtractFn; + auto lextract = lvec_ == SCALAR ? + ExtractFn([&] (int i) { return ldef; }) : + ExtractFn([&] (int i) { return cg.world().extract(ldef, i, loc()); }); + auto rextract = rvec_ == SCALAR ? + ExtractFn([&] (int i) { return rdef; }) : + ExtractFn([&] (int i) { return cg.world().extract(rdef, i, loc()); }); + + int n = (lvec_ != SCALAR ? lhs() : rhs())->type().as()->size(); + Array defs(n); + for (int i = 0; i < n; i++) + defs[i] = cg.world().binop(Token::to_binop(op), lextract(i), rextract(i), loc()); + return cg.world().definite_array(defs, loc()); } const Def* PostfixExpr::remit(CodeGen& cg) const { @@ -577,11 +625,167 @@ const Def* MapExpr::remit(CodeGen& cg, State state, Location eval_loc) const { THORIN_UNREACHABLE; } +const Def* MatrixExpr::remit(CodeGen& cg) const { + switch (kind()) { + case VEC2: + case VEC3: + case VEC4: + case MAT2X2: + case MAT3X3: + case MAT4X4: + case MAT2X3: + case MAT2X4: + case MAT3X2: + case MAT3X4: + case MAT4X2: + case MAT4X3: { + auto mat = type().as(); + int i = 0, n = mat->size(); + Array defs(n); + for (auto arg : args()) { + auto def = cg.remit(arg); + if (arg->type().isa()) { + for (int j = 0, m = def->size(); j < m; j++) + defs[i++] = cg.world().extract(def, j, loc()); + } else { + defs[i++] = def; + } + } + + if (i == 1) { + // repetition constructor, like so: vec4(1.0f) + if (mat->is_vector()) { + for (; i < n; i++) defs[i] = defs[0]; + } else { + // for matrices, this means defs[0] * identity + auto z = cg.world().zero(cg.convert(mat->elem_type()), loc()); + for (int i = 0, n = mat->rows(); i < n; i++) { + for (int j = 0, m = mat->cols(); j < m; j++) + defs[i * mat->cols() + j] = i == j ? defs[0] : z; + } + } + } + return cg.world().definite_array(defs, loc()); + } + case MAT_INVERSE: return emit_inverse(cg, cg.remit(arg(0))); break; + case MAT_DETERMINANT: return emit_determinant(cg, cg.remit(arg(0))); break; + case VEC_CROSS: return emit_cross(cg, cg.remit(arg(0)), cg.remit(arg(1))); break; + case VEC_DOT: return emit_dot(cg, cg.remit(arg(0)), cg.remit(arg(1))); break; + default: break; + } + THORIN_UNREACHABLE; +} + +static const Def* submatrix(CodeGen& cg, int n, int row, int col, const Def* def, const Location& loc) { + // removes row and col from the given square matrix + Array defs((n - 1) * (n - 1)); + for (int i = 0, p = 0; i < n; i++) { + if (i == col) continue; + for (int j = 0, q = 0; j < n; j++) { + if (j == row) continue; + defs[p * (n - 1) + q] = cg.world().extract(def, i * n + j, loc); + q++; + } + p++; + } + return cg.world().definite_array(defs, loc); +} + +static const Def* determinant(CodeGen& cg, int n, const Def* def, const Location& loc) { + if (n == 1) return cg.world().extract(def, 0, loc); + if (n == 2) { + return cg.world().arithop_sub( + cg.world().arithop_mul(cg.world().extract(def, 0, loc), + cg.world().extract(def, 3, loc), loc), + cg.world().arithop_mul(cg.world().extract(def, 1, loc), + cg.world().extract(def, 2, loc), loc), loc); + } + + const Def* sum = nullptr; + for (int i = 0; i < n; i++) { + auto mul = cg.world().arithop_mul( + cg.world().extract(def, i, loc), + determinant(cg, n - 1, submatrix(cg, n, i, 0, def, loc), loc), loc); + sum = sum ? cg.world().binop(i % 2 ? ArithOp_add : ArithOp_sub, sum, mul, loc) : mul; + } + return sum; +} + +const Def* MatrixExpr::emit_determinant(CodeGen& cg, const Def* def) const { + return determinant(cg, arg(0)->type().as()->rows(), def, loc()); +} + +const Def* MatrixExpr::emit_inverse(CodeGen& cg, const Def* def) const { + // the formula used is A.t(com A) = (det A).Id + auto mat = arg(0)->type().as(); + int n = mat->rows(); + + // 1. compute (det A) + auto det = determinant(cg, n, def, loc()); + + // 2. compute t(com A) + Array elems(n * n); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + auto subdet = determinant(cg, n - 1, submatrix(cg, n, i, j, def, loc()), loc()); + elems[i * n + j] = (i + j) % 2 ? cg.world().arithop_minus(subdet, loc()) : subdet; + } + } + + // 3. compute det A == 0 ? 0 : 1/detA + auto elem = cg.convert(mat->elem_type()); + auto zero = cg.world().zero(elem, loc()); + auto one = cg.world().literal(pf32(1.0f), loc()); + auto cmp = cg.world().cmp_eq(det, zero, loc()); + auto inv = cg.world().select(cmp, zero, cg.world().arithop_div(one, det, loc()), loc()); + + for (int i = 0; i < n * n; i++) elems[i] = cg.world().arithop_mul(elems[i], inv, loc()); + + return cg.world().definite_array(elems, loc()); +} + +const Def* MatrixExpr::emit_cross(CodeGen& cg, const Def* ldef, const Def* rdef) const { + Array defs(3); + for (int i = 0; i < 3; i++) { + int j = (i + 1) % 3, k = (i + 2) % 3; + defs[i] = cg.world().arithop_sub( + cg.world().arithop_mul( + cg.world().extract(ldef, j, loc()), + cg.world().extract(rdef, k, loc()), loc()), + cg.world().arithop_mul( + cg.world().extract(ldef, k, loc()), + cg.world().extract(rdef, j, loc()), loc()), loc()); + } + return cg.world().definite_array(defs, loc()); +} + +const Def* MatrixExpr::emit_dot(CodeGen& cg, const Def* ldef, const Def* rdef) const { + int n = arg(0)->type().as()->rows(); + const Def* sum = nullptr; + for (int i = 0; i < n; i++) { + auto mul = cg.world().arithop_mul( + cg.world().extract(ldef, i, loc()), + cg.world().extract(rdef, i, loc()), loc()); + sum = sum ? cg.world().arithop_add(sum, mul, loc()) : mul; + } + return sum; +} + Value FieldExpr::lemit(CodeGen& cg) const { return Value::create_agg(cg.lemit(lhs()), cg.world().literal_qu32(index(), loc())); } const Def* FieldExpr::remit(CodeGen& cg) const { + if (lhs()->type().isa()) { + int i = 0, n = swizzle().size(); + auto def = cg.remit(lhs()); + Array defs(n); + for (auto s : swizzle()) { + defs[i++] = cg.world().extract(def, s, loc()); + } + return n > 1 ? cg.world().definite_array(defs, loc()) : defs[0]; + } + return cg.extract(cg.remit(lhs()), index(), loc()); } diff --git a/src/impala/parser.cpp b/src/impala/parser.cpp index c048a79a1..cf4c82fbe 100644 --- a/src/impala/parser.cpp +++ b/src/impala/parser.cpp @@ -69,7 +69,23 @@ case Token::L_BRACE: \ case Token::RUN_BLOCK: \ case Token::L_BRACKET: \ - case Token::SIMD + case Token::SIMD: \ + case Token::VEC2: \ + case Token::VEC3: \ + case Token::VEC4: \ + case Token::MAT2X2: \ + case Token::MAT3X3: \ + case Token::MAT4X4: \ + case Token::MAT2X3: \ + case Token::MAT2X4: \ + case Token::MAT3X2: \ + case Token::MAT3X4: \ + case Token::MAT4X2: \ + case Token::MAT4X3: \ + case Token::MAT_INVERSE: \ + case Token::MAT_DETERMINANT: \ + case Token::VEC_CROSS: \ + case Token::VEC_DOT #define STMT_NOT_EXPR \ Token::LET: \ @@ -190,16 +206,17 @@ class Parser { void parse_return_param(Fn* fn); // types - const ASTType* parse_type(); - const ArrayASTType* parse_array_type(); - const Typeof* parse_typeof(); - const ASTType* parse_return_type(bool&); - const FnASTType* parse_fn_type(); - const PrimASTType* parse_prim_type(); - const PtrASTType* parse_ptr_type(); - const TupleASTType* parse_tuple_type(); - const SimdASTType* parse_simd_type(); - const ASTTypeApp* parse_type_app(); + const ASTType* parse_type(); + const ArrayASTType* parse_array_type(); + const Typeof* parse_typeof(); + const ASTType* parse_return_type(bool&); + const FnASTType* parse_fn_type(); + const PrimASTType* parse_prim_type(); + const PtrASTType* parse_ptr_type(); + const TupleASTType* parse_tuple_type(); + const SimdASTType* parse_simd_type(); + const MatrixASTType* parse_matrix_type(); + const ASTTypeApp* parse_type_app(); enum class BodyMode { None, Optional, Mandatory }; @@ -229,6 +246,7 @@ class Parser { const Expr* parse_infix_expr(const Expr* lhs); const Expr* parse_postfix_expr(const Expr* lhs); const Expr* parse_primary_expr(); + const MatrixExpr* parse_matrix_expr(); const LiteralExpr* parse_literal_expr(); const CharExpr* parse_char_expr(); const StrExpr* parse_str_expr(); @@ -709,6 +727,21 @@ const ASTType* Parser::parse_type() { case Token::AND: case Token::ANDAND: return parse_ptr_type(); case Token::SIMD: return parse_simd_type(); + + case Token::VEC2: + case Token::VEC3: + case Token::VEC4: + case Token::MAT2X2: + case Token::MAT3X3: + case Token::MAT4X4: + case Token::MAT2X3: + case Token::MAT2X4: + case Token::MAT3X2: + case Token::MAT3X4: + case Token::MAT4X2: + case Token::MAT4X3: + return parse_matrix_type(); + default: { error("type", ""); auto error_type = new ErrorASTType(prev_loc()); @@ -855,6 +888,21 @@ const SimdASTType* Parser::parse_simd_type() { return simd; } +const MatrixASTType* Parser::parse_matrix_type() { + auto mat = loc(new MatrixASTType()); + switch (la()) { +#define IMPALA_MAT_KEY(tok, str, r, c) case Token:: tok : mat->rows_ = r; mat->cols_ = c; break; +#include "impala/tokenlist.h" + default: THORIN_UNREACHABLE; + } + assert(mat->rows_ > 0 && mat->cols_ > 0); + lex(); + expect(Token::L_BRACKET, "vector or matrix type"); + mat->elem_type_ = parse_type(); + expect(Token::R_BRACKET, "vector or matrix type"); + return mat; +} + /* * expressions */ @@ -1008,6 +1056,10 @@ const Expr* Parser::parse_primary_expr() { parse_comma_list("elements of a simd expression", Token::R_BRACKET, [&] { simd->args_.push_back(parse_expr()); }); return simd; } +#define IMPALA_MAT_KEY(tok, str, r, c) case Token:: tok: +#include "impala/tokenlist.h" + return parse_matrix_expr(); + #define IMPALA_LIT(itype, atype) \ case Token::LIT_##itype: #include "impala/tokenlist.h" @@ -1066,6 +1118,19 @@ const Expr* Parser::parse_primary_expr() { } } +const MatrixExpr* Parser::parse_matrix_expr() { + auto vec = loc(new MatrixExpr()); + switch (la()) { +#define IMPALA_MAT_KEY(tok, str, r, c) case Token:: tok: vec->kind_ = MatrixExpr:: tok; break; +#include "impala/tokenlist.h" + default: THORIN_UNREACHABLE; + } + lex(); + expect(Token::L_PAREN, "vector expression"); + parse_comma_list("elements of a vector expression", Token::R_PAREN, [&] { vec->args_.push_back(parse_expr()); }); + return vec; +} + const LiteralExpr* Parser::parse_literal_expr() { LiteralExpr::Kind kind; Box box; diff --git a/src/impala/sema/borrowsema.cpp b/src/impala/sema/borrowsema.cpp index e79e8a063..c84bb4c27 100644 --- a/src/impala/sema/borrowsema.cpp +++ b/src/impala/sema/borrowsema.cpp @@ -76,6 +76,10 @@ void SimdASTType::check(BorrowSema& sema) const { elem_type()->check(sema); } +void MatrixASTType::check(BorrowSema& sema) const { + elem_type()->check(sema); +} + //------------------------------------------------------------------------------ @@ -243,6 +247,11 @@ void MapExpr::check(BorrowSema& sema) const { arg->check(sema); } +void MatrixExpr::check(BorrowSema& sema) const { + for (auto arg : args()) + arg->check(sema); +} + void IfExpr::check(BorrowSema& sema) const { cond()->check(sema); then_expr()->check(sema); diff --git a/src/impala/sema/namesema.cpp b/src/impala/sema/namesema.cpp index 69b0d435d..c68c81f86 100644 --- a/src/impala/sema/namesema.cpp +++ b/src/impala/sema/namesema.cpp @@ -169,6 +169,10 @@ void SimdASTType::check(NameSema& sema) const { elem_type()->check(sema); } +void MatrixASTType::check(NameSema& sema) const { + elem_type()->check(sema); +} + //------------------------------------------------------------------------------ /* @@ -368,6 +372,11 @@ void MapExpr::check(NameSema& sema) const { arg->check(sema); } +void MatrixExpr::check(NameSema& sema) const { + for (auto arg : args()) + arg->check(sema); +} + void IfExpr::check(NameSema& sema) const { cond()->check(sema); then_expr()->check(sema); diff --git a/src/impala/sema/typesema.cpp b/src/impala/sema/typesema.cpp index 267aa300e..f1e7d3695 100644 --- a/src/impala/sema/typesema.cpp +++ b/src/impala/sema/typesema.cpp @@ -60,11 +60,12 @@ class TypeSema : public TypeTable { } return true; } - Type scalar_type(const Expr*); + Type scalar_type(Type); bool is_int(Type t); bool is_float(Type t); bool expect_int(const Expr*); bool expect_int_or_bool(const Expr*); + bool expect_num(Type, const Expr*); void expect_num(const Expr*); Type expect_type(const Expr* expr, Type found, TypeExpectation expected); Type expect_type(const Expr* expr, TypeExpectation expected) { return expect_type(expr, expr->type(), expected); } @@ -135,8 +136,7 @@ class TypeSema : public TypeTable { //------------------------------------------------------------------------------ -Type TypeSema::scalar_type(const Expr* e) { - Type t = e->type(); +Type TypeSema::scalar_type(Type t) { if (auto simd = t.isa()) { return simd->elem_type(); } @@ -153,8 +153,7 @@ bool TypeSema::is_float(Type t) { } bool TypeSema::expect_int(const Expr* expr) { - auto t = scalar_type(expr); - + auto t = scalar_type(expr->type()); if (!t->is_error() && !is_int(t)) { error(expr) << "expected integer type but found " << t << "\n"; return false; @@ -163,8 +162,7 @@ bool TypeSema::expect_int(const Expr* expr) { } bool TypeSema::expect_int_or_bool(const Expr* expr) { - auto t = scalar_type(expr); - + auto t = scalar_type(expr->type()); if (!t->is_error() && !t->is_bool() && !is_int(t)) { error(expr) << "expected integer or boolean type but found " << t << "\n"; return false; @@ -172,11 +170,17 @@ bool TypeSema::expect_int_or_bool(const Expr* expr) { return true; } -void TypeSema::expect_num(const Expr* expr) { - auto t = scalar_type(expr); - - if (!t->is_error() && !t->is_bool() && !is_int(t) && !is_float(t)) +bool TypeSema::expect_num(Type t, const Expr* expr) { + auto scalar = scalar_type(t); + if (!scalar->is_error() && !scalar->is_bool() && !is_int(scalar) && !is_float(scalar)) { error(expr) << "expected number type but found " << t << "\n"; + return false; + } + return true; +} + +void TypeSema::expect_num(const Expr* expr) { + expect_num(expr->type(), expr); } Type TypeSema::expect_type(const Expr* expr, Type found_type, TypeExpectation expected) { @@ -365,6 +369,19 @@ Type SimdASTType::check(TypeSema& sema) const { } } +Type MatrixASTType::check(TypeSema& sema) const { + auto type = sema.check(elem_type()); + auto scalar = sema.scalar_type(type); + + if (!sema.is_int(scalar) && !sema.is_float(scalar)) { + error(this) << "only floating point and integer types are supported for " + << (is_vector() ? "vectors" : "matrices") << "\n"; + return sema.type_error(); + } + + return sema.matrix_type(type, rows(), cols()); +} + //------------------------------------------------------------------------------ Type ValueDecl::check(TypeSema& sema) const { return check(sema, Type()); } @@ -830,13 +847,10 @@ Type InfixExpr::check(TypeSema& sema, TypeExpectation expected) const { return sema.type_bool(); case ADD: case SUB: - case MUL: case DIV: + case MUL: case REM: { - auto type = sema.check(lhs(), sema.check(rhs(), expected)); - sema.expect_num(lhs()); - sema.expect_num(rhs()); - return type; + return check_arith_op(sema); } case SHL: case SHR: { @@ -863,12 +877,9 @@ Type InfixExpr::check(TypeSema& sema, TypeExpectation expected) const { case MUL_ASGN: case DIV_ASGN: case REM_ASGN: { - sema.check(rhs(), sema.check(lhs())); - if (sema.expect_lvalue(lhs())) { - sema.expect_num(lhs()); - sema.expect_num(rhs()); + check_arith_op(sema); + if (sema.expect_lvalue(lhs())) return sema.unit(); - } break; } case AND_ASGN: @@ -891,6 +902,65 @@ Type InfixExpr::check(TypeSema& sema, TypeExpectation expected) const { return sema.type_error(); } +inline Type matrix_elem_type(Type t) { + if (auto mat = t.isa()) + return mat->elem_type(); + return t; +} + +Type InfixExpr::check_arith_op(TypeSema& sema) const { + auto ltype = sema.check(lhs()); + auto rtype = sema.check(rhs()); + + if (ltype->is_error() || rtype->is_error()) return sema.type_error(); + + auto lelem = matrix_elem_type(ltype); + auto relem = matrix_elem_type(rtype); + bool lscalar = ltype == lelem; + bool rscalar = rtype == relem; + + lvec_ = lscalar ? SCALAR : (ltype.as()->is_vector() ? VECTOR : MATRIX); + rvec_ = rscalar ? SCALAR : (ltype.as()->is_vector() ? VECTOR : MATRIX); + + if (kind() == REM) { + // For REM, the types must be integers + if (!sema.is_int(sema.scalar_type(lelem)) || !sema.is_int(sema.scalar_type(relem))) { + error(this) << "the modulus '%' operator is only valid on integers\n"; + return sema.type_error(); + } + } else { + sema.expect_num(lelem, lhs()); + sema.expect_num(relem, rhs()); + } + + // if operands are both scalars or both vectors, types must be equal + if ((lscalar & rscalar) && ltype == rtype) return ltype; + if ((lscalar ^ rscalar) && lelem == relem) { return lscalar ? rtype : ltype; } + + if (!lscalar && !rscalar && lelem == relem) { + auto lmat = ltype.as(); + auto rmat = rtype.as(); + + if (kind() == MUL) { + // vector * matrix, matrix * vector, or matrix * matrix multiplication + if (!lmat->is_vector()) { + if (lmat->cols() == rmat->rows()) return rmat; + } else { + if (lmat->rows() == rmat->rows()) return lmat; + } + } else if (lmat->rows() == rmat->rows()) { + // addition, subtraction, division, ... + return lmat; + } + } + + if (!ltype->is_error() && !rtype->is_error()) { + error(this) << "types do not match for operator '" << Token(loc(), (Token::Kind)kind()) + << "', got " << ltype << " and " << rtype << " \n"; + } + return sema.type_error(); +} + Type PostfixExpr::check(TypeSema& sema, TypeExpectation expected) const { // TODO check if operator supports the type sema.check(lhs(), expected); @@ -1098,6 +1168,10 @@ Type TypeSema::check_call(const MapExpr* expr, FnType fn_poly, const ASTTypes& t } Type FieldExpr::check(TypeSema& sema, TypeExpectation expected) const { + auto ltype = sema.check(lhs()); + if (ltype.isa()) + return check_as_matrix(sema, expected.type()); + if (auto type = check_as_struct(sema, expected.type())) return type; @@ -1114,6 +1188,9 @@ Type FieldExpr::check_as_struct(TypeSema& sema, Type expected) const { ltype = sema.check(lhs()); } + if (auto mat = ltype.isa()) + return check_as_matrix(sema, expected); + if (auto struct_app = ltype.isa()) { if (auto field_decl = struct_app->struct_abs_type()->struct_decl()->field_decl(symbol())) { index_ = field_decl->index(); @@ -1126,6 +1203,47 @@ Type FieldExpr::check_as_struct(TypeSema& sema, Type expected) const { return Type(); } +Type FieldExpr::check_as_matrix(TypeSema& sema, Type expected) const { + auto ltype = sema.check(lhs()).as(); + if (ltype->is_vector()) { + const char* str = symbol().str(); + uint32_t len_dst = strlen(str); + if (len_dst > 4) { + error(this) << "too many components in swizzle operation\n"; + return sema.type_error(); + } + + uint32_t len_src = 0; + for (auto* p = str; *p; p++) { + uint32_t l; + switch (*p) { + case 'x': l = 0; break; + case 'y': l = 1; break; + case 'z': l = 2; break; + case 'w': l = 3; break; + default: + error(this) << "incorrect character in swizzle operation, only 'x', 'y', 'z', and 'w' are allowed\n"; + return sema.type_error(); + } + swizzle_.push_back(l); + len_src = std::max(len_src, l); + } + + if (len_src >= ltype->rows()) { + const char xyzw[] = "xyzw"; + error(this) << "vector component '" << xyzw[len_src] << "' is out of bounds\n"; + return sema.type_error(); + } + + if (len_dst == 1) index_ = len_src; + + return len_dst > 1 ? static_cast(sema.matrix_type(ltype->elem_type(), len_dst)) : ltype->elem_type(); + } + + error(this) << "matrix components cannot be accessed by field names, use the array syntax instead\n"; + return sema.type_error(); +} + Type MapExpr::check(TypeSema& sema, TypeExpectation expected) const { if (auto field_expr = lhs()->isa()) { if (field_expr->check_as_struct(sema, sema.unknown_type())) @@ -1194,6 +1312,200 @@ Type MapExpr::check_as_method_call(TypeSema& sema, TypeExpectation expected) con return sema.type_error(); } +Type MatrixExpr::check(TypeSema& sema, TypeExpectation expected) const { + if (!num_args()) + error(this) << "arguments expected\n"; + + for (auto arg : args()) { + // prevent error message explosion + if (sema.check(arg)->is_error()) return sema.type_error(); + } + + uint32_t rows, cols; + switch (kind()) { +#define IMPALA_MAT_KEY(tok, str, r, c) case Token:: tok : rows = r; cols = c; break; +#include "impala/tokenlist.h" + default: THORIN_UNREACHABLE; + } + + if (rows > 0) { + if (cols > 1) { return check_matrix_args(sema, rows, cols); } + else { return check_vector_args(sema, rows); } + } else { + // special functions (e.g. dot, cross, ...) have rows = cols = 0 + switch (kind()) { + case Token::MAT_INVERSE: return check_inverse(sema); + case Token::MAT_DETERMINANT: return check_determinant(sema); + case Token::VEC_DOT: return check_dot(sema); + case Token::VEC_CROSS: return check_cross(sema); + default: break; + } + } + + THORIN_UNREACHABLE; +} + +Type MatrixExpr::check_vector_args(TypeSema& sema, uint32_t rows) const { + // vectors can be constructed by concatenating scalars with smaller vectors + auto type = arg(0)->type(); + + Type elem_type; + if (auto mat = type.isa()) + elem_type = mat->elem_type(); + else + elem_type = type; + + auto scalar = sema.scalar_type(elem_type); + if (!sema.is_int(scalar) && !sema.is_float(scalar)) { + error(this) << "only floating point and integer types are supported for vectors\n"; + return sema.type_error(); + } + + uint32_t count = 0; + for (auto arg : args()) { + if (arg->type().isa() || arg->type().isa()) { + if (arg->type() != elem_type) { + error(this) << "mismatching types in vector initializers\n"; + return sema.type_error(); + } + count += 1; + } else if (auto mat = arg->type().isa()) { + if (mat->elem_type() != elem_type) { + error(this) << "mismatching types in vector initializers: got " + << mat->elem_type() << ", expected " << elem_type << "\n"; + return sema.type_error(); + } + + if (!mat->is_vector()) { + error(this) << "matrices are not allowed in vector initializers\n"; + return sema.type_error(); + } + count += mat->rows(); + } else { + error(this) << "incorrect type for vector initializer\n"; + return sema.type_error(); + } + } + + if (count != rows && count > 1) { + error(this) << "incorrect number of initializers for vector: got " + << count << ", expected " << rows << "\n"; + return sema.type_error(); + } + + return sema.matrix_type(elem_type, rows); +} + +Type MatrixExpr::check_matrix_args(TypeSema& sema, uint32_t rows, uint32_t cols) const { + // matrices can be constructed by providing the columns as vectors, or by listing all the components one by one. + auto type = arg(0)->type(); + for (auto arg : args()) { + if (arg->type() != type) { + error(this) << "mismatching types in matrix initializers\n"; + return sema.type_error(); + } + } + + if (auto mat = type.isa()) { + if (!mat->is_vector()) { + error(this) << "matrices cannot be initialized with matrices\n"; + return sema.type_error(); + } + + if (num_args() != cols || mat->rows() != rows) { + error(this) << "matrix initializer sizes do not match matrix dimensions\n"; + return sema.type_error(); + } + + return sema.matrix_type(mat->elem_type(), rows, cols); + } + + auto scalar = sema.scalar_type(type); + if (!sema.is_int(scalar) && !sema.is_float(scalar)) { + error(this) << "only floating point and integer types are supported for matrices\n"; + return sema.type_error(); + } + + if (num_args() > 1 && num_args() != rows * cols) { + error(this) << "incorrect number of initializers for matrix: got " + << num_args() << ", expected " << rows * cols << "\n"; + } + + return sema.matrix_type(type, rows, cols); +} + +Type MatrixExpr::check_inverse(TypeSema& sema) const { + if (num_args() != 1) { + error(this) << "incorrect number of arguments for the inverse function\n"; + return sema.type_error(); + } + + auto mat = arg(0)->type().isa(); + + if (!mat || mat->is_vector() || mat->rows() != mat->cols()) { + error(this) << "invalid operand type for the inverse function: got " + << arg(0)->type() << ", expected square matrix\n"; + return sema.type_error(); + } + + if (!sema.is_float(sema.scalar_type(mat->elem_type()))) { + error(this) << "the inverse operation is only available for floating point matrices\n"; + return sema.type_error(); + } + + return mat; +} + +Type MatrixExpr::check_determinant(TypeSema& sema) const { + if (num_args() != 1) { + error(this) << "incorrect number of arguments for the determinant function\n"; + return sema.type_error(); + } + + auto mat = arg(0)->type().isa(); + + if (!mat || mat->is_vector() || mat->rows() != mat->cols()) { + error(this) << "invalid operand type for the determinant function\n"; + return sema.type_error(); + } + + return mat->elem_type(); +} + +Type MatrixExpr::check_cross(TypeSema& sema) const { + if (num_args() != 2) { + error(this) << "incorrect number of arguments for the cross function\n"; + return sema.type_error(); + } + + auto lmat = arg(0)->type().isa(); + auto rmat = arg(1)->type().isa(); + + if (lmat != rmat || !lmat || !lmat->is_vector() || lmat->rows() != 3) { + error(this) << "invalid operand types for the cross function\n"; + return sema.type_error(); + } + + return sema.matrix_type(lmat->elem_type(), 3); +} + +Type MatrixExpr::check_dot(TypeSema& sema) const { + if (num_args() != 2) { + error(this) << "incorrect number of arguments for the dot function\n"; + return sema.type_error(); + } + + auto lmat = arg(0)->type().isa(); + auto rmat = arg(1)->type().isa(); + + if (lmat != rmat || !lmat || !lmat->is_vector()) { + error(this) << "invalid operand types for the dot function\n"; + return sema.type_error(); + } + + return lmat->elem_type(); +} + Type BlockExprBase::check(TypeSema& sema, TypeExpectation expected) const { THORIN_PUSH(sema.cur_block_, this); for (auto stmt : stmts()) diff --git a/src/impala/sema/typetable.h b/src/impala/sema/typetable.h index 1c4bb2efa..9e596a36e 100644 --- a/src/impala/sema/typetable.h +++ b/src/impala/sema/typetable.h @@ -32,6 +32,9 @@ class TypeTable { SimdType simd_type(Type elem_type, uint64_t size) { return join(new SimdTypeNode(*this, elem_type, size)); } + MatrixType matrix_type(Type elem_type, uint32_t rows, uint32_t cols = 1) { + return join(new MatrixTypeNode(*this, elem_type, rows, cols)); + } MutPtrType mut_ptr_type(Type referenced_type, int addr_space = 0) { return join(new MutPtrTypeNode(*this, referenced_type, addr_space)); } diff --git a/src/impala/sema/unifiable.cpp b/src/impala/sema/unifiable.cpp index 905029ecf..5913c5a49 100644 --- a/src/impala/sema/unifiable.cpp +++ b/src/impala/sema/unifiable.cpp @@ -306,9 +306,17 @@ bool TraitAppNode::equal(const Unifiable* other) const { bool SimdTypeNode::equal(const Unifiable* other) const { assert(this->is_unified()); - return Unifiable::equal(other) && (this->size() == other->as()->size()); + return Unifiable::equal(other) && (size() == other->as()->size()); } +bool MatrixTypeNode::equal(const Unifiable* other) const { + assert(this->is_unified()); + return Unifiable::equal(other) && + rows() == other->as()->rows() && + cols() == other->as()->cols(); +} + + //------------------------------------------------------------------------------ /* @@ -332,10 +340,8 @@ bool DefiniteArrayTypeNode::is_subtype(const TypeNode* other) const { return dim_eq && other->isa() && elem_type()->is_subtype(*other->as()->elem_type()); } - -bool SimdTypeNode::is_subtype(const TypeNode* other) const { - return this->equal(other); -} +bool SimdTypeNode::is_subtype(const TypeNode* other) const { return equal(other); } +bool MatrixTypeNode::is_subtype(const TypeNode* other) const { return equal(other); } /* * TODO merge this code with equal @@ -475,6 +481,10 @@ Type SimdTypeNode::vinstantiate(SpecializeMap& map) const { return map[this] = *typetable().simd_type(elem_type()->specialize(map), size()); } +Type MatrixTypeNode::vinstantiate(SpecializeMap& map) const { + return map[this] = *typetable().matrix_type(elem_type()->specialize(map), rows(), cols()); +} + Type FnTypeNode::vinstantiate(SpecializeMap& map) const { return map[this] = *typetable().fn_type(specialize_args(map)); } diff --git a/src/impala/sema/unifiable.h b/src/impala/sema/unifiable.h index 76a8d7ee5..f345a25a0 100644 --- a/src/impala/sema/unifiable.h +++ b/src/impala/sema/unifiable.h @@ -42,6 +42,7 @@ class OwnedPtrTypeNode; typedef Proxy OwnedPtr class PrimTypeNode; typedef Proxy PrimType; class PtrTypeNode; typedef Proxy PtrType; class SimdTypeNode; typedef Proxy SimdType; +class MatrixTypeNode; typedef Proxy MatrixType; class StructAbsTypeNode; typedef Proxy StructAbsType; class StructAppTypeNode; typedef Proxy StructAppType; class TraitAbsNode; typedef Proxy TraitAbs; @@ -162,6 +163,7 @@ enum Kind { Kind_noret, Kind_owned_ptr, Kind_simd, + Kind_matrix, Kind_struct_abs, Kind_struct_app, Kind_trait_abs, @@ -661,6 +663,30 @@ class SimdTypeNode : public ArrayTypeNode { const uint64_t size_; }; +class MatrixTypeNode : public ArrayTypeNode { +public: + MatrixTypeNode(TypeTable& tt, Type elem_type, uint32_t rows, uint32_t cols) + : ArrayTypeNode(tt, Kind_matrix, { elem_type }), rows_(rows), cols_(cols) + {} + + uint32_t rows() const { return rows_; } + uint32_t cols() const { return cols_; } + uint32_t size() const { return rows_ * cols_; } + + bool is_vector() const { return cols_ == 1; } + + virtual std::ostream& stream(std::ostream&) const override; + virtual bool is_subtype(const TypeNode*) const override; + virtual bool equal(const Unifiable*) const override; + +private: + virtual Type vinstantiate(SpecializeMap&) const override; + virtual const thorin::Type* convert(CodeGen&) const override; + + uint32_t rows_; + uint32_t cols_; +}; + /** * Represents a declared trait. * A trait consists of a name, a number of declared methods and a number of diff --git a/src/impala/stream.cpp b/src/impala/stream.cpp index 0331c1ea0..cbe5a08e1 100644 --- a/src/impala/stream.cpp +++ b/src/impala/stream.cpp @@ -80,6 +80,12 @@ std::ostream& IndefiniteArrayTypeNode::stream(std::ostream& os) const { return s std::ostream& SimdTypeNode::stream(std::ostream& os) const { return streamf(os, "simd[% * %]", elem_type(), size()); } std::ostream& StructAbsTypeNode::stream(std::ostream& os) const { return os << struct_decl_->symbol(); } +std::ostream& MatrixTypeNode::stream(std::ostream& os) const { + return is_vector() ? + streamf(os, "vec%[%]", rows(), elem_type()) : + streamf(os, "mat%x%[%]", rows(), cols(), elem_type()); +} + std::ostream& StructAppTypeNode::stream(std::ostream& os) const { os << struct_abs_type()->struct_decl()->symbol(); if (num_args() != 0) @@ -115,6 +121,11 @@ std::ostream& PtrASTType::stream(std::ostream& os) const { std::ostream& DefiniteArrayASTType::stream(std::ostream& os) const { return streamf(os, "[% * %]", elem_type(), dim()); } std::ostream& IndefiniteArrayASTType::stream(std::ostream& os) const { return streamf(os, "[%]", elem_type()); } std::ostream& SimdASTType::stream(std::ostream& os) const { return streamf(os, "simd[% * %]", elem_type(), size()); } +std::ostream& MatrixASTType::stream(std::ostream& os) const { + return is_vector() ? + streamf(os, "vec%[%]", rows(), elem_type()) : + streamf(os, "mat%x%[%]", rows(), cols(), elem_type()); +} std::ostream& TupleASTType::stream(std::ostream& os) const { return stream_list(os, args(), [&](const ASTType* type) { os << type; }, "(", ")"); @@ -381,6 +392,15 @@ std::ostream& SimdExpr::stream(std::ostream& os) const { return stream_list(os, args(), [&](const Expr* expr) { os << expr; }, "simd[", "]"); } +std::ostream& MatrixExpr::stream(std::ostream& os) const { + switch (kind()) { +#define IMPALA_KEY_VEC(tok, str) case tok : os << str; break; +#include "impala/tokenlist.h" + default: THORIN_UNREACHABLE; + } + return stream_list(os, args(), [&](const Expr* expr) { os << expr; }, "(", ")"); +} + std::ostream& PrefixExpr::stream(std::ostream& os) const { Prec r = PrecTable::prefix_r[kind()]; Prec old = prec; diff --git a/src/impala/tokenlist.h b/src/impala/tokenlist.h index f40e4f7ba..ec303496e 100644 --- a/src/impala/tokenlist.h +++ b/src/impala/tokenlist.h @@ -104,6 +104,29 @@ IMPALA_KEY(TYPEOF, "typeof") IMPALA_KEY(WHILE, "while") IMPALA_KEY(SIMD, "simd") +#ifndef IMPALA_MAT_KEY +#define IMPALA_MAT_KEY(tok, str, r, c) IMPALA_KEY(tok, str) +#endif + +IMPALA_MAT_KEY(VEC2, "vec2", 2, 1) +IMPALA_MAT_KEY(VEC3, "vec3", 3, 1) +IMPALA_MAT_KEY(VEC4, "vec4", 4, 1) +IMPALA_MAT_KEY(MAT2X2, "mat2x2", 2, 2) +IMPALA_MAT_KEY(MAT3X3, "mat3x3", 3, 3) +IMPALA_MAT_KEY(MAT4X4, "mat4x4", 4, 4) +IMPALA_MAT_KEY(MAT2X3, "mat2x3", 2, 3) +IMPALA_MAT_KEY(MAT2X4, "mat2x4", 2, 4) +IMPALA_MAT_KEY(MAT3X2, "mat3x2", 3, 2) +IMPALA_MAT_KEY(MAT3X4, "mat3x4", 3, 4) +IMPALA_MAT_KEY(MAT4X2, "mat4x2", 4, 2) +IMPALA_MAT_KEY(MAT4X3, "mat4x3", 4, 3) + +IMPALA_MAT_KEY(MAT_INVERSE, "inverse", 0, 0) +IMPALA_MAT_KEY(MAT_DETERMINANT, "determinant", 0, 0) +IMPALA_MAT_KEY(VEC_DOT, "dot", 0, 0) +IMPALA_MAT_KEY(VEC_CROSS, "cross", 0, 0) + +#undef IMPALA_MAT_KEY #undef IMPALA_KEY #ifndef IMPALA_MISC