Skip to content
6 changes: 5 additions & 1 deletion src/impala/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ bool MapExpr::is_lvalue() const {
}

bool PrefixExpr::is_lvalue() const { return kind() == MUL; }
bool FieldExpr::is_lvalue() const { return lhs()->is_lvalue() || lhs()->type().isa<PtrType>(); }
bool FieldExpr::is_lvalue() const {
if (lhs()->type().isa<MatrixType>())
return (strlen(symbol().str()) == 1) && lhs()->is_lvalue();
return lhs()->is_lvalue() || lhs()->type().isa<PtrType>();
}
bool CastExpr::is_lvalue() const { return lhs()->is_lvalue(); }

//------------------------------------------------------------------------------
Expand Down
70 changes: 69 additions & 1 deletion src/impala/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,25 @@ class SimdASTType : public ArrayASTType {
friend class Parser;
};

class MatrixASTType : public ArrayASTType {
public:
uint32_t rows() const { return rows_; }
uint32_t cols() const { return cols_; }

bool is_vector() const { return cols_ == 1; }

virtual std::ostream& stream(std::ostream&) const override;
virtual void check(NameSema&) const override;
virtual void check(BorrowSema&) const override;

private:
virtual Type check(TypeSema&) const override;

thorin::u32 rows_, cols_;

friend class Parser;
};

//------------------------------------------------------------------------------

/*
Expand Down Expand Up @@ -1097,10 +1116,19 @@ class InfixExpr : public Expr {
virtual std::ostream& stream(std::ostream&) const override;
virtual Type check(TypeSema&, TypeExpectation) const override;

Type check_arith_op(TypeSema&) const;
const thorin::Def* emit(CodeGen&, TokenKind, const thorin::Def*, const thorin::Def*) const;

Kind kind_;
AutoPtr<const Expr> lhs_;
AutoPtr<const Expr> rhs_;

mutable enum VecOp {
SCALAR,
VECTOR,
MATRIX
} lvec_, rvec_;

friend class Parser;
};

Expand Down Expand Up @@ -1138,21 +1166,25 @@ class FieldExpr : public Expr {
const Identifier* identifier() const { return identifier_; }
Symbol symbol() const { return identifier()->symbol(); }
uint32_t index() const { return index_; }
const std::vector<uint32_t>& swizzle() const { return swizzle_; }
virtual bool is_lvalue() const override;
virtual void take_address() const override;
virtual void check(NameSema&) const override;
virtual void check(BorrowSema&) const override;
Type check_as_struct(TypeSema&, Type) const;

private:
virtual std::ostream& stream(std::ostream&) const override;
virtual Type check(TypeSema&, TypeExpectation) const override;
virtual thorin::Value lemit(CodeGen&) const override;
virtual const thorin::Def* remit(CodeGen&) const override;

Type check_as_struct(TypeSema&, Type) const;
Type check_as_matrix(TypeSema&, Type) const;

AutoPtr<const Expr> lhs_;
AutoPtr<const Identifier> identifier_;
mutable uint32_t index_ = uint32_t(-1);
mutable std::vector<uint32_t> swizzle_;

friend class Parser;
friend class MapExpr; // remove this
Expand Down Expand Up @@ -1334,6 +1366,42 @@ class MapExpr : public Expr, public Args, public TypeArgs {
friend class TypeSema;
};

class MatrixExpr : public Expr, public Args {
public:
enum Kind {
#define IMPALA_MAT_KEY(tok, str, r, c) tok = Token:: tok,
#include "tokenlist.h"
};

Kind kind() const { return kind_; }
virtual void check(NameSema&) const override;
virtual void check(BorrowSema&) const override;

private:
virtual std::ostream& stream(std::ostream&) const override;
virtual Type check(TypeSema&, TypeExpectation) const override;
virtual const thorin::Def* remit(CodeGen&) const override;

Type check_vector_args(TypeSema&, uint32_t) const;
Type check_matrix_args(TypeSema&, uint32_t, uint32_t) const;

Type check_inverse(TypeSema&) const;
Type check_determinant(TypeSema&) const;
Type check_cross(TypeSema&) const;
Type check_dot(TypeSema&) const;

const thorin::Def* emit_inverse(CodeGen&, const thorin::Def*) const;
const thorin::Def* emit_determinant(CodeGen&, const thorin::Def*) const;
const thorin::Def* emit_cross(CodeGen&, const thorin::Def*, const thorin::Def*) const;
const thorin::Def* emit_dot(CodeGen&, const thorin::Def*, const thorin::Def*) const;

friend class CodeGen;
friend class Parser;
friend class TypeSema;

Kind kind_;
};

class StmtLikeExpr : public Expr {};

class BlockExprBase : public StmtLikeExpr {
Expand Down
39 changes: 24 additions & 15 deletions src/impala/cgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,36 +40,36 @@ class CGen {
}

static std::ostream& cgen_error(const ASTNode* node) {
return error(node) << "cannot generate C interface : ";
return error(node) << "cannot generate C interface: ";
}

// Generates a C type from an Impala type
static bool ctype_from_impala(const Type type, std::string& ctype_prefix, std::string& ctype_suffix) {
if (auto prim_type = type.isa<PrimType>()) {
switch (prim_type->primtype_kind()) {
case PrimType_i8:
ctype_prefix = "char"; ctype_suffix = "";
ctype_prefix = "int8_t"; ctype_suffix = "";
return true;
case PrimType_i16:
ctype_prefix = "short"; ctype_suffix = "";
ctype_prefix = "int16_t"; ctype_suffix = "";
return true;
case PrimType_i32:
ctype_prefix = "int"; ctype_suffix = "";
ctype_prefix = "int32_t"; ctype_suffix = "";
return true;
case PrimType_i64:
ctype_prefix = "long long"; ctype_suffix = "";
ctype_prefix = "int64_t"; ctype_suffix = "";
return true;
case PrimType_u8:
ctype_prefix = "unsigned char"; ctype_suffix = "";
ctype_prefix = "uint8_t"; ctype_suffix = "";
return true;
case PrimType_u16:
ctype_prefix = "unsigned short"; ctype_suffix = "";
ctype_prefix = "uint16_t"; ctype_suffix = "";
return true;
case PrimType_u32:
ctype_prefix = "unsigned int"; ctype_suffix = "";
ctype_prefix = "uint32_t"; ctype_suffix = "";
return true;
case PrimType_u64:
ctype_prefix = "unsigned long long"; ctype_suffix = "";
ctype_prefix = "uint64_t"; ctype_suffix = "";
return true;
case PrimType_f16:
ctype_prefix = "half"; ctype_suffix = "";
Expand Down Expand Up @@ -136,14 +136,15 @@ class CGen {
// &[T * N] -> T*
// &T -> T*

if (auto array_type = ptr_type->referenced_type().isa<ArrayType>()) {
if (!ctype_from_impala(array_type->elem_type(), ctype_prefix, ctype_suffix))
return false;
} else {
if (!ctype_from_impala(ptr_type->referenced_type(), ctype_prefix, ctype_suffix))
return false;
auto ref_type = ptr_type->referenced_type();
if (auto array_type = ref_type.isa<ArrayType>()) {
ref_type.clear();
ref_type = array_type->elem_type();
}

if (!ctype_from_impala(ref_type, ctype_prefix, ctype_suffix))
return false;

ctype_prefix += "*";
ctype_suffix = "";
return true;
Expand All @@ -163,6 +164,13 @@ class CGen {
return true;
}

if (auto matrix_type = type.isa<MatrixType>()) {
if (!ctype_from_impala(matrix_type->elem_type(), ctype_prefix, ctype_suffix))
return false;
ctype_suffix = "[" + std::to_string(matrix_type->size()) + "]";
return true;
}

return false;
}

Expand Down Expand Up @@ -347,6 +355,7 @@ bool generate_c_interface(const ModContents* mod, const CGenOptions& opts, std::
o << "/* " << opts.file_name << " : Impala interface file generated by impala */\n"
<< "#ifndef " << opts.guard << "\n"
<< "#define " << opts.guard << "\n\n"
<< "#include <stdint.h>\n\n"
<< "#ifdef __cplusplus\n"
<< "extern \"C\" {\n"
<< "#endif\n"
Expand Down
Loading