Skip to content

Commit eb36bb7

Browse files
committed
Improve heap safety in allocator
1 parent 3779c5a commit eb36bb7

File tree

1 file changed

+91
-52
lines changed

1 file changed

+91
-52
lines changed

Client/multiplayer_sa/CMultiplayerSA_FixMallocAlign.cpp

Lines changed: 91 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ namespace mta::memory
2020
constexpr std::uint32_t NULL_PAGE_BOUNDARY = 0x10000;
2121
constexpr std::uint32_t MAX_ADDRESS_SPACE = 0xFFFFFFFF;
2222
constexpr std::uint32_t POINTER_SIZE = 4;
23+
constexpr std::uint32_t POINTER_METADATA_OVERHEAD = POINTER_SIZE * 2;
24+
constexpr std::uint32_t METADATA_MAGIC = 0x4D544100; // 'MTA\0'
25+
constexpr std::uint32_t METADATA_MAGIC_MASK = 0xFFFFFFFE;
26+
constexpr std::uint32_t METADATA_FLAG_VIRTUALALLOC = 0x1;
2327

2428
constexpr bool is_valid_alignment(std::size_t alignment) noexcept
2529
{
@@ -28,7 +32,7 @@ namespace mta::memory
2832

2933
void* SafeMallocAlignVirtual(std::size_t size, std::size_t alignment) noexcept;
3034

31-
// Aligned malloc - stores pointer at result-4
35+
// Aligned malloc - stores pointer at result-4 and metadata at result-8
3236
[[nodiscard]] void* SafeMallocAlign(std::size_t size, std::size_t alignment) noexcept
3337
{
3438
// Check alignment
@@ -56,18 +60,19 @@ namespace mta::memory
5660
const std::uint32_t align_u32 = static_cast<std::uint32_t>(alignment);
5761

5862
// Prevent intermediate overflow
59-
if (size_u32 > UINT32_MAX - align_u32)
63+
if (align_u32 > UINT32_MAX - POINTER_METADATA_OVERHEAD)
6064
{
6165
errno = ENOMEM;
6266
return nullptr;
6367
}
64-
// Now safe to add size_u32 + align_u32
65-
if (size_u32 + align_u32 > UINT32_MAX - POINTER_SIZE)
68+
const std::uint32_t alignment_overhead = align_u32 + POINTER_METADATA_OVERHEAD;
69+
70+
if (size_u32 > UINT32_MAX - alignment_overhead)
6671
{
6772
errno = ENOMEM;
6873
return nullptr;
6974
}
70-
const std::uint32_t total_size = size_u32 + align_u32 + POINTER_SIZE;
75+
const std::uint32_t total_size = size_u32 + alignment_overhead;
7176

7277
void* raw_memory = malloc(total_size);
7378
if (!raw_memory)
@@ -78,16 +83,16 @@ namespace mta::memory
7883

7984
const std::uint32_t raw_addr = reinterpret_cast<std::uint32_t>(raw_memory);
8085

81-
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_SIZE - align_u32 + 1)
86+
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_METADATA_OVERHEAD - align_u32 + 1)
8287
{
8388
free(raw_memory);
8489
errno = ENOMEM;
8590
return nullptr;
8691
}
8792

88-
const std::uint32_t aligned_addr = (raw_addr + POINTER_SIZE + align_u32 - 1) & ~(align_u32 - 1);
93+
const std::uint32_t aligned_addr = (raw_addr + POINTER_METADATA_OVERHEAD + align_u32 - 1) & ~(align_u32 - 1);
8994

90-
if (aligned_addr < raw_addr + POINTER_SIZE || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
95+
if (aligned_addr < raw_addr + POINTER_METADATA_OVERHEAD || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
9196
{
9297
free(raw_memory);
9398
errno = EINVAL;
@@ -98,15 +103,20 @@ namespace mta::memory
98103

99104
// Validate store location
100105
void** store_location = reinterpret_cast<void**>(aligned_addr - POINTER_SIZE);
101-
if (reinterpret_cast<std::uint32_t>(store_location) < raw_addr ||
102-
reinterpret_cast<std::uint32_t>(store_location) > raw_addr + total_size - POINTER_SIZE)
106+
std::uint32_t* metadata_location = reinterpret_cast<std::uint32_t*>(aligned_addr - POINTER_METADATA_OVERHEAD);
107+
const std::uint32_t store_addr = reinterpret_cast<std::uint32_t>(store_location);
108+
const std::uint32_t metadata_addr = reinterpret_cast<std::uint32_t>(metadata_location);
109+
110+
if (store_addr < raw_addr || store_addr > raw_addr + total_size - POINTER_SIZE ||
111+
metadata_addr < raw_addr || metadata_addr > raw_addr + total_size - POINTER_SIZE)
103112
{
104113
free(raw_memory);
105114
errno = EFAULT;
106115
return nullptr;
107116
}
108117

109118
*store_location = raw_memory;
119+
*metadata_location = METADATA_MAGIC;
110120

111121
return result;
112122
}
@@ -127,18 +137,19 @@ namespace mta::memory
127137
const std::uint32_t align_u32 = static_cast<std::uint32_t>(alignment);
128138

129139
// Prevent intermediate overflow
130-
if (size_u32 > UINT32_MAX - align_u32)
140+
if (align_u32 > UINT32_MAX - POINTER_METADATA_OVERHEAD)
131141
{
132142
errno = ENOMEM;
133143
return nullptr;
134144
}
135-
// Now safe to add size_u32 + align_u32
136-
if (size_u32 + align_u32 > UINT32_MAX - POINTER_SIZE)
145+
const std::uint32_t alignment_overhead = align_u32 + POINTER_METADATA_OVERHEAD;
146+
147+
if (size_u32 > UINT32_MAX - alignment_overhead)
137148
{
138149
errno = ENOMEM;
139150
return nullptr;
140151
}
141-
const std::uint32_t total_size = size_u32 + align_u32 + POINTER_SIZE;
152+
const std::uint32_t total_size = size_u32 + alignment_overhead;
142153

143154
void* raw_memory = malloc(total_size);
144155
if (!raw_memory)
@@ -149,16 +160,16 @@ namespace mta::memory
149160

150161
const std::uint32_t raw_addr = reinterpret_cast<std::uint32_t>(raw_memory);
151162

152-
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_SIZE - align_u32 + 1)
163+
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_METADATA_OVERHEAD - align_u32 + 1)
153164
{
154165
free(raw_memory);
155166
errno = ENOMEM;
156167
return nullptr;
157168
}
158169

159-
const std::uint32_t aligned_addr = (raw_addr + POINTER_SIZE + align_u32 - 1) & ~(align_u32 - 1);
170+
const std::uint32_t aligned_addr = (raw_addr + POINTER_METADATA_OVERHEAD + align_u32 - 1) & ~(align_u32 - 1);
160171

161-
if (aligned_addr < raw_addr + POINTER_SIZE || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
172+
if (aligned_addr < raw_addr + POINTER_METADATA_OVERHEAD || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
162173
{
163174
free(raw_memory);
164175
errno = EINVAL;
@@ -168,15 +179,20 @@ namespace mta::memory
168179
void* result = reinterpret_cast<void*>(aligned_addr);
169180

170181
void** store_location = reinterpret_cast<void**>(aligned_addr - POINTER_SIZE);
171-
if (reinterpret_cast<std::uint32_t>(store_location) < raw_addr ||
172-
reinterpret_cast<std::uint32_t>(store_location) > raw_addr + total_size - POINTER_SIZE)
182+
std::uint32_t* metadata_location = reinterpret_cast<std::uint32_t*>(aligned_addr - POINTER_METADATA_OVERHEAD);
183+
const std::uint32_t store_addr = reinterpret_cast<std::uint32_t>(store_location);
184+
const std::uint32_t metadata_addr = reinterpret_cast<std::uint32_t>(metadata_location);
185+
186+
if (store_addr < raw_addr || store_addr > raw_addr + total_size - POINTER_SIZE ||
187+
metadata_addr < raw_addr || metadata_addr > raw_addr + total_size - POINTER_SIZE)
173188
{
174189
free(raw_memory);
175190
errno = EFAULT;
176191
return nullptr;
177192
}
178193

179194
*store_location = raw_memory;
195+
*metadata_location = METADATA_MAGIC;
180196

181197
return result;
182198
}
@@ -206,23 +222,25 @@ namespace mta::memory
206222
const std::uint32_t align_u32 = static_cast<std::uint32_t>(alignment);
207223
const std::uint32_t padding = (align_u32 <= 64) ? 32 : VIRTUALALLOC_PADDING;
208224

209-
if (align_u32 > UINT32_MAX - POINTER_SIZE)
225+
if (align_u32 > UINT32_MAX - POINTER_METADATA_OVERHEAD)
210226
{
211227
errno = ENOMEM;
212228
return nullptr;
213229
}
214-
if (align_u32 + POINTER_SIZE > UINT32_MAX - padding)
230+
const std::uint32_t alignment_overhead = align_u32 + POINTER_METADATA_OVERHEAD;
231+
232+
if (alignment_overhead > UINT32_MAX - padding)
215233
{
216234
errno = ENOMEM;
217235
return nullptr;
218236
}
219-
if (size_u32 > UINT32_MAX - align_u32 - POINTER_SIZE - padding)
237+
if (size_u32 > UINT32_MAX - alignment_overhead - padding)
220238
{
221239
errno = ENOMEM;
222240
return nullptr;
223241
}
224242

225-
const DWORD total_size = size_u32 + align_u32 + POINTER_SIZE + padding;
243+
const DWORD total_size = size_u32 + alignment_overhead + padding;
226244

227245
void* raw_ptr = VirtualAlloc(nullptr, static_cast<SIZE_T>(total_size), MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
228246
if (!raw_ptr)
@@ -233,17 +251,17 @@ namespace mta::memory
233251

234252
const std::uint32_t raw_addr = reinterpret_cast<std::uint32_t>(raw_ptr);
235253

236-
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_SIZE - align_u32 + 1)
254+
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_METADATA_OVERHEAD - align_u32 + 1)
237255
{
238256
BOOL vfree_result = VirtualFree(raw_ptr, 0, MEM_RELEASE);
239257
(void)vfree_result;
240258
errno = ENOMEM;
241259
return nullptr;
242260
}
243261

244-
const std::uint32_t aligned_addr = (raw_addr + POINTER_SIZE + align_u32 - 1) & ~(align_u32 - 1);
262+
const std::uint32_t aligned_addr = (raw_addr + POINTER_METADATA_OVERHEAD + align_u32 - 1) & ~(align_u32 - 1);
245263

246-
if (aligned_addr < raw_addr + POINTER_SIZE || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
264+
if (aligned_addr < raw_addr + POINTER_METADATA_OVERHEAD || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
247265
{
248266
BOOL vfree_result = VirtualFree(raw_ptr, 0, MEM_RELEASE);
249267
(void)vfree_result;
@@ -255,8 +273,12 @@ namespace mta::memory
255273

256274
// Validate store location
257275
void** store_location = reinterpret_cast<void**>(aligned_addr - POINTER_SIZE);
258-
if (reinterpret_cast<std::uint32_t>(store_location) < raw_addr ||
259-
reinterpret_cast<std::uint32_t>(store_location) > raw_addr + total_size - POINTER_SIZE)
276+
std::uint32_t* metadata_location = reinterpret_cast<std::uint32_t*>(aligned_addr - POINTER_METADATA_OVERHEAD);
277+
const std::uint32_t store_addr = reinterpret_cast<std::uint32_t>(store_location);
278+
const std::uint32_t metadata_addr = reinterpret_cast<std::uint32_t>(metadata_location);
279+
280+
if (store_addr < raw_addr || store_addr > raw_addr + total_size - POINTER_SIZE ||
281+
metadata_addr < raw_addr || metadata_addr > raw_addr + total_size - POINTER_SIZE)
260282
{
261283
BOOL vfree_result = VirtualFree(raw_ptr, 0, MEM_RELEASE);
262284
(void)vfree_result;
@@ -265,6 +287,7 @@ namespace mta::memory
265287
}
266288

267289
*store_location = raw_ptr;
290+
*metadata_location = METADATA_MAGIC | METADATA_FLAG_VIRTUALALLOC;
268291

269292
return result;
270293
}
@@ -282,7 +305,13 @@ namespace mta::memory
282305
return;
283306
}
284307

308+
if (ptr_addr < POINTER_METADATA_OVERHEAD)
309+
{
310+
return;
311+
}
312+
285313
void** read_location = reinterpret_cast<void**>(ptr_addr - POINTER_SIZE);
314+
std::uint32_t* metadata_location = reinterpret_cast<std::uint32_t*>(ptr_addr - POINTER_METADATA_OVERHEAD);
286315

287316
// Validate memory readable
288317
MEMORY_BASIC_INFORMATION mbi_read;
@@ -293,7 +322,35 @@ namespace mta::memory
293322
return;
294323
}
295324

325+
const std::uint32_t metadata_addr = reinterpret_cast<std::uint32_t>(metadata_location);
326+
const std::uint32_t base_addr = reinterpret_cast<std::uint32_t>(mbi_read.BaseAddress);
327+
328+
if (mbi_read.RegionSize == 0 || mbi_read.RegionSize > static_cast<SIZE_T>(MAX_ADDRESS_SPACE))
329+
{
330+
return;
331+
}
332+
333+
const std::uint32_t region_size_u32 = static_cast<std::uint32_t>(mbi_read.RegionSize);
334+
335+
if (base_addr > MAX_ADDRESS_SPACE - region_size_u32)
336+
{
337+
return;
338+
}
339+
340+
const std::uint32_t region_end = base_addr + region_size_u32;
341+
342+
if (region_size_u32 < POINTER_SIZE || metadata_addr < base_addr || metadata_addr > region_end - POINTER_SIZE)
343+
{
344+
return;
345+
}
346+
296347
void* original_ptr = *read_location;
348+
const std::uint32_t metadata = *metadata_location;
349+
350+
if ((metadata & METADATA_MAGIC_MASK) != METADATA_MAGIC)
351+
{
352+
return;
353+
}
297354

298355
if (!original_ptr)
299356
{
@@ -308,12 +365,12 @@ namespace mta::memory
308365
}
309366

310367
const std::uint32_t distance = ptr_addr - original_addr;
311-
if (distance > MAX_ALIGNMENT + POINTER_SIZE)
368+
if (distance > MAX_ALIGNMENT + POINTER_METADATA_OVERHEAD)
312369
{
313370
return; // Beyond maximum possible alignment
314371
}
315372

316-
if (ptr_addr < POINTER_SIZE || original_addr > ptr_addr - POINTER_SIZE)
373+
if (original_addr > ptr_addr - POINTER_SIZE)
317374
{
318375
return; // Violates our storage pattern
319376
}
@@ -323,31 +380,13 @@ namespace mta::memory
323380
return;
324381
}
325382

326-
MEMORY_BASIC_INFORMATION mbi;
327-
SIZE_T mbi_result = VirtualQuery(original_ptr, &mbi, sizeof(mbi));
328-
329-
if (mbi_result == sizeof(mbi))
383+
if ((metadata & METADATA_FLAG_VIRTUALALLOC) != 0)
330384
{
331-
const std::uint32_t base_addr = reinterpret_cast<std::uint32_t>(mbi.AllocationBase);
332-
333-
// Validate region size
334-
if (mbi.RegionSize > 0 && mbi.RegionSize <= static_cast<SIZE_T>(MAX_ADDRESS_SPACE) &&
335-
base_addr <= MAX_ADDRESS_SPACE - static_cast<std::uint32_t>(mbi.RegionSize))
336-
{
337-
const std::uint32_t region_size_u32 = static_cast<std::uint32_t>(mbi.RegionSize);
338-
const std::uint32_t region_end = base_addr + region_size_u32;
339-
340-
// Use VirtualFree if matches
341-
if (mbi.Type == MEM_PRIVATE && mbi.State == MEM_COMMIT && original_addr >= base_addr && original_addr < region_end)
342-
{
343-
BOOL vfree_result = VirtualFree(mbi.AllocationBase, 0, MEM_RELEASE);
344-
(void)vfree_result;
345-
return;
346-
}
347-
}
385+
BOOL vfree_result = VirtualFree(original_ptr, 0, MEM_RELEASE);
386+
(void)vfree_result;
387+
return;
348388
}
349389

350-
// Use free for malloc
351390
free(original_ptr);
352391
}
353392
} // namespace mta::memory

0 commit comments

Comments
 (0)