|
3 | 3 |
|
4 | 4 | from lite_bootstrap import import_checker |
5 | 5 | from lite_bootstrap.bootstrappers.base import BaseBootstrapper |
| 6 | +from lite_bootstrap.instruments.cors_instrument import CorsConfig, CorsInstrument |
6 | 7 | from lite_bootstrap.instruments.healthchecks_instrument import ( |
7 | 8 | HealthChecksConfig, |
8 | 9 | HealthChecksInstrument, |
|
22 | 23 | if import_checker.is_litestar_installed: |
23 | 24 | import litestar |
24 | 25 | from litestar.config.app import AppConfig |
| 26 | + from litestar.config.cors import CORSConfig |
25 | 27 | from litestar.contrib.opentelemetry import OpenTelemetryConfig |
26 | 28 | from litestar.plugins.prometheus import PrometheusConfig, PrometheusController |
27 | 29 |
|
|
31 | 33 |
|
32 | 34 | @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) |
33 | 35 | class LitestarConfig( |
34 | | - HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusBootstrapperConfig, SentryConfig |
| 36 | + CorsConfig, HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusBootstrapperConfig, SentryConfig |
35 | 37 | ): |
36 | 38 | application_config: "AppConfig" = dataclasses.field(default_factory=lambda: AppConfig()) |
37 | 39 | opentelemetry_excluded_urls: list[str] = dataclasses.field(default_factory=list) |
38 | 40 | prometheus_additional_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict) |
39 | 41 |
|
40 | 42 |
|
| 43 | +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) |
| 44 | +class LitestarCorsInstrument(CorsInstrument): |
| 45 | + bootstrap_config: LitestarConfig |
| 46 | + |
| 47 | + def bootstrap(self) -> None: |
| 48 | + self.bootstrap_config.application_config.cors_config = CORSConfig( |
| 49 | + allow_origins=self.bootstrap_config.cors_allowed_origins, |
| 50 | + allow_methods=self.bootstrap_config.cors_allowed_methods, # type: ignore[arg-type] |
| 51 | + allow_headers=self.bootstrap_config.cors_allowed_headers, |
| 52 | + allow_credentials=self.bootstrap_config.cors_allowed_credentials, |
| 53 | + allow_origin_regex=self.bootstrap_config.cors_allowed_origin_regex, |
| 54 | + expose_headers=self.bootstrap_config.cors_exposed_headers, |
| 55 | + max_age=self.bootstrap_config.cors_max_age, |
| 56 | + ) |
| 57 | + |
| 58 | + |
41 | 59 | @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) |
42 | 60 | class LitestarHealthChecksInstrument(HealthChecksInstrument): |
43 | 61 | bootstrap_config: LitestarConfig |
@@ -116,6 +134,7 @@ class LitestarBootstrapper(BaseBootstrapper["litestar.Litestar"]): |
116 | 134 | __slots__ = "bootstrap_config", "instruments" |
117 | 135 |
|
118 | 136 | instruments_types: typing.ClassVar = [ |
| 137 | + LitestarCorsInstrument, |
119 | 138 | LitestarOpenTelemetryInstrument, |
120 | 139 | LitestarSentryInstrument, |
121 | 140 | LitestarHealthChecksInstrument, |
|
0 commit comments