Skip to content
93 changes: 62 additions & 31 deletions src/ArbCall/ArbArgTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ Struct for conversion between C argument types in the Arb
documentation and Julia types.
"""
struct ArbArgTypes
supported::Dict{String,DataType}
supported::Dict{String,Union{DataType,UnionAll}}
unsupported::Set{String}
supported_reversed::Dict{DataType,String}
supported_reversed::Dict{Union{DataType,UnionAll},String}
end

function Base.getindex(arbargtypes::ArbArgTypes, key::AbstractString)
Expand All @@ -22,65 +22,96 @@ end

# Define the conversions we use for the rest of the code
const arbargtypes = ArbArgTypes(
Dict{String,DataType}(
Dict{String,Union{DataType,UnionAll}}(
# Primitive
"void" => Cvoid,
"void *" => Ptr{Cvoid},
"int" => Cint,
"slong" => Int,
"ulong" => UInt,
"double" => Cdouble,
"double *" => Vector{Float64},
"double" => Float64,
"complex_double" => ComplexF64,
"void *" => Ptr{Cvoid},
"char *" => Cstring,
"slong *" => Vector{Int},
"ulong *" => Vector{UInt},
"double *" => Vector{Float64},
"complex_double *" => Vector{ComplexF64},
# gmp.h
"mpz_t" => BigInt,
# mpfr.h
"mpfr_t" => BigFloat,
"mpfr_rnd_t" => Base.MPFR.MPFRRoundingMode,
# mag.h
"mag_t" => Mag,
# nfloat.h
"nfloat_ptr" => NFloat,
"nfloat_srcptr" => NFloat,
"gr_ctx_t" => nfloat_ctx_struct, # Actually in gr_types.h
# arf.h
"arf_t" => Arf,
"arf_rnd_t" => arb_rnd,
# acf.h
"acf_t" => Acf,
# arb.h
"arb_t" => Arb,
"acb_t" => Acb,
"mag_t" => Mag,
"arb_srcptr" => ArbVector,
"arb_ptr" => ArbVector,
"acb_srcptr" => AcbVector,
"arb_srcptr" => ArbVector,
# acb.h
"acb_t" => Acb,
"acb_ptr" => AcbVector,
"acb_srcptr" => AcbVector,
# arb_poly.h
"arb_poly_t" => ArbPoly,
# acb_poly.h
"acb_poly_t" => AcbPoly,
# arb_mat.h
"arb_mat_t" => ArbMatrix,
# acb_mat.h
"acb_mat_t" => AcbMatrix,
"arf_rnd_t" => arb_rnd,
"mpfr_t" => BigFloat,
"mpfr_rnd_t" => Base.MPFR.MPFRRoundingMode,
"mpz_t" => BigInt,
"char *" => Cstring,
"slong *" => Vector{Int},
"ulong *" => Vector{UInt},
),
Set(["FILE *", "fmpr_t", "fmpr_rnd_t", "flint_rand_t", "bool_mat_t"]),
Dict{DataType,String}(
Set(["FILE *", "flint_rand_t"]),
Dict{Union{DataType,UnionAll},String}(
# Primitive
Cvoid => "void",
Ptr{Cvoid} => "void *",
Cint => "int",
Int => "slong",
UInt => "ulong",
Cdouble => "double",
Vector{Float64} => "double *",
Float64 => "double",
ComplexF64 => "complex_double",
Ptr{Cvoid} => "void *",
Cstring => "char *",
Vector{Int} => "slong *",
Vector{UInt} => "ulong *",
Vector{Float64} => "double *",
Vector{ComplexF64} => "complex_double *",
# gmp.h
BigInt => "mpz_t",
# mpfr.h
BigFloat => "mpfr_t",
Base.MPFR.MPFRRoundingMode => "mpfr_rnd_t",
# mag.h
Mag => "mag_t",
# nfloat.h
NFloat => "nfloat_ptr",
nfloat_ctx_struct => "gr_ctx_t", # Actually in gr_types.h
# arf.h
Arf => "arf_t",
arb_rnd => "arf_rnd_t",
# acf.h
Acf => "acf_t",
# arb.h
Arb => "arb_t",
Acb => "acb_t",
Mag => "mag_t",
ArbVector => "arb_ptr",
# acb.h
Acb => "acb_t",
AcbVector => "acb_ptr",
# arb_poly.h
ArbPoly => "arb_poly_t",
# acb_poly.h
AcbPoly => "acb_poly_t",
# arb_mat.h
ArbMatrix => "arb_mat_t",
# acb_mat.h
AcbMatrix => "acb_mat_t",
arb_rnd => "arf_rnd_t",
BigFloat => "mpfr_t",
Base.MPFR.MPFRRoundingMode => "mpfr_rnd_t",
BigInt => "mpz_t",
Cstring => "char *",
Vector{Int} => "slong *",
Vector{UInt} => "ulong *",
),
)
4 changes: 4 additions & 0 deletions src/ArbCall/ArbCall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import ..Arblib:
cstructtype,
arb_rnd,
mag_struct,
nfloat_struct,
nfloat_ctx_struct,
_get_nfloat_ctx_struct,
arf_struct,
acf_struct,
arb_struct,
Expand All @@ -23,6 +26,7 @@ import ..Arblib:
arb_mat_struct,
acb_mat_struct,
MagLike,
NFloatLike,
ArfLike,
AcfLike,
ArbLike,
Expand Down
2 changes: 1 addition & 1 deletion src/ArbCall/ArbFPWrapFunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ function jlargs(af::ArbFPWrapFunction)
cargs[end] == Carg{Cint}(:flags, false) ||
throw(ArgumentError("expected last argument to be flags::Cint, got $(cargs[end])"))

args = [:($(name(carg))::$(jltype(carg))) for carg in cargs[n+1:end-1]]
args = [jlarg(carg) for carg in cargs[n+1:end-1]]

if basetype(af) == Float64
kwargs = [
Expand Down
52 changes: 37 additions & 15 deletions src/ArbCall/ArbFunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,23 @@ is_series_method(af::ArbFunction) =
(jltype(first(arguments(af))) <: Union{Arblib.ArbPolyLike,Arblib.AcbPolyLike})

const jlfname_prefixes = (
"double",
"cdouble",
"mag",
"nfloat",
"ctx",
"arf",
"acf",
"arb",
"acb",
"mag",
"mat",
"vec",
"poly",
"scalar",
"mat",
"fpwrap",
"double",
"cdouble",
"scalar",
)
const jlfname_suffixes =
("si", "ui", "d", "mag", "arf", "acf", "arb", "acb", "mpz", "mpfr", "str")
("si", "ui", "d", "str", "mpz", "mpfr", "mag", "nfloat", "arf", "acf", "arb", "acb")

function jlfname(
arbfname::AbstractString;
Expand Down Expand Up @@ -117,15 +119,15 @@ jlfname_series(af::ArbFunction) = jlfname_series(arbfname(af))
function jlargs(af::ArbFunction; argument_detection::Bool = true)
cargs = arguments(af)

jl_arg_names_types = Tuple{Symbol,Any}[]
args = Expr[]
kwargs = Expr[]

prec_kwarg = false
rnd_kwarg = false
flags_kwarg = false
for (i, carg) in enumerate(cargs)
if !argument_detection
push!(jl_arg_names_types, (name(carg), jltype(carg)))
push!(args, jlarg(carg))
continue
end

Expand All @@ -146,13 +148,13 @@ function jlargs(af::ArbFunction; argument_detection::Bool = true)
push!(kwargs, extract_rounding_argument(carg))
elseif i > 1 && is_length_argument(carg, cargs[i-1])
push!(kwargs, extract_length_argument(carg, cargs[i-1]))
elseif is_ctx_argument(carg)
push!(kwargs, extract_ctx_argument(carg, first(cargs)))
else
push!(jl_arg_names_types, (name(carg), jltype(carg)))
push!(args, jlarg(carg))
end
end

args = [:($a::$T) for (a, T) in jl_arg_names_types]

return args, kwargs
end

Expand Down Expand Up @@ -199,13 +201,22 @@ function jlcode(af::ArbFunction, jl_fname = jlfname(af))

returnT = returntype(af)
cargs = arguments(af)
where_type_parameters = unique(reduce(vcat, type_parameters.(cargs)))

func_full_args_call = :($jl_fname($(jl_full_args...)))

func_full_args_header = if isempty(where_type_parameters)
func_full_args_call
else
Expr(:where, func_full_args_call, where_type_parameters...)
end

func_full_args = :(
function $jl_fname($(jl_full_args...))
func_full_args_body = :(
begin
__ret = ccall(
Arblib.@libflint($(arbfname(af))),
$returnT,
$(Expr(:tuple, ctype.(cargs)...)),
$(Expr(:tuple, carg_expr.(cargs)...)),
$(name.(cargs)...),
)
$(
Expand All @@ -220,7 +231,11 @@ function jlcode(af::ArbFunction, jl_fname = jlfname(af))
end
)

func_full_args = Expr(:function, func_full_args_header, func_full_args_body)

if is_series_method(af)
@assert isempty(where_type_parameters) # Currently not supported for series methods

# Note that this currently doesn't respect any custom function
# name given as an argument.
jl_fname_series = jlfname_series(af)
Expand All @@ -242,9 +257,16 @@ function jlcode(af::ArbFunction, jl_fname = jlfname(af))
if isempty(jl_kwargs)
return code
else
func_kwarg_args_call = :($jl_fname($(jl_args...); $(jl_kwargs...)))
func_kwarg_args_header = if isempty(where_type_parameters)
func_kwarg_args_call
else
Expr(:where, func_kwarg_args_call, where_type_parameters...)
end

return quote
$code
$jl_fname($(jl_args...); $(jl_kwargs...)) = $jl_fname($(name.(cargs)...))
$func_kwarg_args_header = $jl_fname($(name.(cargs)...))
end
end
end
Expand Down
Loading
Loading