Skip to content

Commit fa9c284

Browse files
committed
Change parameterization of LogNormal and fix parsing bug
1 parent 0c81dde commit fa9c284

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

torchtree/cli/evolution.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,19 +1458,19 @@ def create_ucln_prior(branch_model_id):
14581458
mean = Parameter.json_factory(
14591459
f"{branch_model_id}.rates.prior.mean", **{"tensor": [0.001]}
14601460
)
1461-
scale = Parameter.json_factory(
1462-
f"{branch_model_id}.rates.prior.scale", **{"tensor": [1.0]}
1461+
stdev = Parameter.json_factory(
1462+
f"{branch_model_id}.rates.prior.stdev", **{"tensor": [0.1]}
14631463
)
14641464
mean[CONSTRAINT.LOWER.value] = 0.0
1465-
scale[CONSTRAINT.LOWER.value] = 0.0
1465+
stdev[CONSTRAINT.LOWER.value] = 0.0
14661466
joint_list.append(
14671467
Distribution.json_factory(
14681468
f"{branch_model_id}.rates.prior",
14691469
"LogNormal",
14701470
f"{branch_model_id}.rates",
14711471
{
14721472
"mean": mean,
1473-
"scale": scale,
1473+
"stdev": stdev,
14741474
},
14751475
)
14761476
)
@@ -1485,7 +1485,7 @@ def create_ucln_prior(branch_model_id):
14851485
Distribution.json_factory(
14861486
f"{branch_model_id}.rates.scale.prior",
14871487
"torch.distributions.Gamma",
1488-
f"{branch_model_id}.rates.prior.scale",
1488+
f"{branch_model_id}.rates.prior.stdev",
14891489
{
14901490
"concentration": 0.5396,
14911491
"rate": 2.6184,

torchtree/core/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,9 @@ def process_object(data, dic):
181181
raise JSONParseError("Missing `id' and `type' keys") from None
182182

183183
if id_ in dic:
184-
raise JSONParseError("Object with ID `{id_}' already exists")
184+
raise JSONParseError(f"Object with ID `{id_}' already exists")
185185
if "type" not in data:
186-
raise JSONParseError("Object with ID `{id_}' does not have a type")
186+
raise JSONParseError(f"Object with ID `{id_}' does not have a type")
187187

188188
try:
189189
klass = get_class(data["type"])

torchtree/distributions/log_normal.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,26 @@ class LogNormal(torch.distributions.LogNormal):
1818
:attr:`scale`.
1919
2020
:param mean: mean of the distribution
21-
:param scale: standard deviation of log of the distribution
22-
:param scale_real: standard deviation of the distribution
21+
:param scale: scale (sigma) parameter of log of the distribution
22+
:param stdev: standard deviation of the distribution
2323
"""
2424

2525
def __init__(
2626
self,
2727
mean: Union[Tensor, float],
2828
scale: Union[Tensor, float, None] = None,
29-
scale_real: Union[Tensor, float, None] = None,
29+
stdev: Union[Tensor, float, None] = None,
3030
validate_args=None,
3131
) -> None:
32-
if (scale is not None) + (scale_real is not None) != 1:
33-
raise ValueError("Exactly one of scale or scale_real may be specified.")
32+
if (scale is not None) + (stdev is not None) != 1:
33+
raise ValueError("Exactly one of scale or stdev may be specified.")
3434

3535
if scale is not None:
3636
log_mean = log(mean) if isinstance(mean, Number) else mean.log()
3737
var = scale * scale if isinstance(scale, Number) else scale.pow(2)
3838
loc = log_mean - var * 0.5
3939
else:
40-
temp = 1.0 + scale_real * scale_real / (mean * mean)
40+
temp = 1.0 + stdev * stdev / (mean * mean)
4141
if isinstance(temp, Number):
4242
loc = math.log(mean / math.sqrt(temp))
4343
scale = math.sqrt(math.log(temp))

0 commit comments

Comments
 (0)