-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnum_narratives.py
More file actions
103 lines (84 loc) · 2.78 KB
/
num_narratives.py
File metadata and controls
103 lines (84 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
from src.number_narratives import assess_narratives
def parse_arguments(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for the number-of-narratives run.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default), argparse reads ``sys.argv[1:]``, so
            calling the function with no arguments behaves as before.

    Returns:
        A namespace with the attributes ``worker_model_name``,
        ``max_model_len``, ``temperature``, ``top_p``, ``max_tokens``,
        ``repetition_penalty``, ``dataset`` and ``num_narratives``.

    Raises:
        SystemExit: Raised by argparse on invalid input (for example a
            ``--dataset`` value outside the allowed choices) or on ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="Run number-of-narratives generation for coherence analysis"
    )

    # Model selection and context window.
    parser.add_argument(
        '--worker_model_name',
        type=str,
        default='unsloth_qwen_0.5B',
        help='Model name to generate narratives (default: unsloth_qwen_0.5B)'
    )
    parser.add_argument(
        '--max_model_len',
        type=int,
        default=8192,
        help='Maximum context length for the model (default: 8192)'
    )

    # Sampling / generation parameters.
    parser.add_argument(
        '--temperature',
        type=float,
        default=0.6,
        help='Temperature for text generation (default: 0.6)'
    )
    parser.add_argument(
        '--top_p',
        type=float,
        default=0.8,
        help='Top-p sampling parameter (default: 0.8)'
    )
    parser.add_argument(
        '--max_tokens',
        type=int,
        default=4096,
        help='Maximum number of tokens to generate (default: 4096)'
    )
    parser.add_argument(
        '--repetition_penalty',
        type=float,
        default=1.05,
        help='Penalty applied to repeating tokens (default: 1.05)'
    )

    # Experiment settings.
    parser.add_argument(
        '--dataset',
        type=str,
        default='adult',
        choices=['adult', 'titanic', 'california', 'diabetes'],
        help='Dataset to use for experiments (default: adult)'
    )
    parser.add_argument(
        '--num_narratives',
        type=int,
        default=8,
        help='Number of narratives (K) to generate per sample (default: 8)'
    )

    # Passing argv through lets callers and tests supply their own argument
    # list instead of depending on the process-wide sys.argv.
    return parser.parse_args(argv)
def display_config(args: argparse.Namespace) -> None:
    """Print a human-readable summary of the run configuration to stdout.

    Args:
        args: Parsed options. The attributes ``worker_model_name``,
            ``max_model_len``, ``temperature``, ``top_p``, ``max_tokens``,
            ``repetition_penalty``, ``dataset`` and ``num_narratives`` are
            read and printed.
    """
    print("\n========== NUMBER OF NARRATIVES GENERATION ==========")
    print(f"Model name: {args.worker_model_name}")
    # A space follows the colon so this entry is formatted like all the others.
    print(f"Model context length: {args.max_model_len}")

    print("\n----- Generation Parameters -----")
    print(f"Temperature: {args.temperature}")
    print(f"Top-p: {args.top_p}")
    print(f"Max tokens: {args.max_tokens}")
    print(f"Repetition penalty: {args.repetition_penalty}")

    print("\n----- Settings -----")
    print(f"Dataset: {args.dataset}")
    print(f"Num narratives (K): {args.num_narratives}")
    print("=============================================\n")
def main() -> None:
    """Script entry point: parse the CLI options, echo them, and start the run.

    The parsed options are printed with ``display_config`` and then forwarded
    to ``assess_narratives`` as keyword arguments.
    """
    cli = parse_arguments()
    display_config(cli)

    # Map CLI option names onto the keyword names assess_narratives is called with.
    run_kwargs = {
        "model_name": cli.worker_model_name,
        "temperature": cli.temperature,
        "top_p": cli.top_p,
        "dataset": cli.dataset,
        "max_tokens": cli.max_tokens,
        "repetition_penalty": cli.repetition_penalty,
        "max_model_len": cli.max_model_len,
        "num_narratives": cli.num_narratives,
    }
    assess_narratives(**run_kwargs)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()