Skip to content

Commit

Permalink
Merge pull request #414 from jnu/spydantic2
Browse files Browse the repository at this point in the history
Support for Pydantic 2 & additional python versions
  • Loading branch information
koxudaxi authored Jul 2, 2024
2 parents d09be6a + c6ea807 commit 2f41065
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 20 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,22 @@ Options:
-r, --generate-routers Generate modular api with multiple routers using RouterAPI (for bigger applications).
--specify-tags Use along with --generate-routers to generate specific routers from given list of tags.
-c, --custom-visitors PATH - A custom visitor that adds variables to the template.
-d, --output-model-type Specify a Pydantic base model to use (see [datamodel-code-generator](https://github.com/koxudaxi/datamodel-code-generator); default is `pydantic.BaseModel`).
-p, --python-version Specify a Python version to target (default is `3.8`).
--install-completion Install completion for the current shell.
--show-completion Show completion for the current shell, to copy it
or customize the installation.
--help Show this message and exit.
```

### Pydantic 2 support

Specify the Pydantic 2 `BaseModel` version in the command line, for example:

```sh
$ fastapi-codegen --input api.yaml --output app --output-model-type pydantic_v2.BaseModel
```

## Example
### OpenAPI
```sh
Expand Down
47 changes: 27 additions & 20 deletions fastapi_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from typing import Any, Dict, List, Optional

import typer
from datamodel_code_generator import LiteralType, PythonVersion, chdir
from datamodel_code_generator import DataModelType, LiteralType, PythonVersion, chdir
from datamodel_code_generator.format import CodeFormatter
from datamodel_code_generator.imports import Import, Imports
from datamodel_code_generator.model import get_data_model_types
from datamodel_code_generator.reference import Reference
from datamodel_code_generator.types import DataType
from jinja2 import Environment, FileSystemLoader
Expand Down Expand Up @@ -57,6 +58,12 @@ def main(
None, "--custom-visitor", "-c"
),
disable_timestamp: bool = typer.Option(False, "--disable-timestamp"),
output_model_type: DataModelType = typer.Option(
DataModelType.PydanticBaseModel.value, "--output-model-type", "-d"
),
python_version: PythonVersion = typer.Option(
PythonVersion.PY_38.value, "--python-version", "-p"
),
) -> None:
input_name: str = input_file
input_text: str
Expand All @@ -69,31 +76,20 @@ def main(
else:
model_path = MODEL_PATH

if enum_field_as_literal:
return generate_code(
input_name,
input_text,
encoding,
output_dir,
template_dir,
model_path,
enum_field_as_literal, # type: ignore[arg-type]
custom_visitors=custom_visitors,
disable_timestamp=disable_timestamp,
generate_routers=generate_routers,
specify_tags=specify_tags,
)
return generate_code(
input_name,
input_text,
encoding,
output_dir,
template_dir,
model_path,
enum_field_as_literal=enum_field_as_literal or None,
custom_visitors=custom_visitors,
disable_timestamp=disable_timestamp,
generate_routers=generate_routers,
specify_tags=specify_tags,
output_model_type=output_model_type,
python_version=python_version,
)


Expand All @@ -119,6 +115,8 @@ def generate_code(
disable_timestamp: bool = False,
generate_routers: Optional[bool] = None,
specify_tags: Optional[str] = None,
output_model_type: DataModelType = DataModelType.PydanticBaseModel,
python_version: PythonVersion = PythonVersion.PY_38,
) -> None:
if not model_path:
model_path = MODEL_PATH
Expand All @@ -130,10 +128,19 @@ def generate_code(
template_dir = (
BUILTIN_MODULAR_TEMPLATE_DIR if generate_routers else BUILTIN_TEMPLATE_DIR
)
if enum_field_as_literal:
parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal) # type: ignore[arg-type]
else:
parser = OpenAPIParser(input_text)

data_model_types = get_data_model_types(output_model_type, python_version)

parser = OpenAPIParser(
input_text,
enum_field_as_literal=enum_field_as_literal,
data_model_type=data_model_types.data_model,
data_model_root_type=data_model_types.root_model,
data_model_field_type=data_model_types.field_model,
data_type_manager_type=data_model_types.data_type_manager,
dump_resolve_reference_action=data_model_types.dump_resolve_reference_action,
)

with chdir(output_dir):
models = parser.parse()
output = output_dir / model_path
Expand All @@ -153,7 +160,7 @@ def generate_code(
)

results: Dict[Path, str] = {}
code_formatter = CodeFormatter(PythonVersion.PY_38, Path().resolve())
code_formatter = CodeFormatter(python_version, Path().resolve())

template_vars: Dict[str, object] = {"info": parser.parse_info()}
visitors: List[Visitor] = []
Expand Down

0 comments on commit 2f41065

Please sign in to comment.