Skip to content

Commit

Permalink
Revert "⚡️ testing pattern proposal"
Browse files Browse the repository at this point in the history
This reverts commit b660f24.
  • Loading branch information
z3z1ma committed Oct 11, 2022
1 parent b660f24 commit 828e4ab
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 74 deletions.
7 changes: 1 addition & 6 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
{
"python.formatting.provider": "black",
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.formatting.provider": "black"
}
11 changes: 6 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ bottle = "^0.12.23"
orjson = "^3.8.0"
fastapi = "^0.85.0"
uvicorn = { extras = ["standard"], version = "^0.18.3" }
sqlfluff = "^1.3.2"
jinja2-simple-tags = "^0.4.0"
# Streamlit Workbench Dependencies
streamlit = { version = ">=1.0.0", optional = true }
streamlit-ace = { version = ">=0.1.0", optional = true }
Expand All @@ -48,7 +46,9 @@ feedparser = { version = "^6.0.10", optional = true }
duckcli = { version = "^0.2.1", optional = true }
dbt-duckdb = { version = "^1.2.0", optional = true }
dbt-sqlite = { version = "^1.1.3", optional = true }
dbt-postgres = { version = "1.2.1", optional = true }
sqlfluff = "^1.3.2"
jinja2-simple-tags = "^0.4.0"
dbt-postgres = "1.2.1"

[tool.poetry.dev-dependencies]
black = ">=21.9b0"
Expand All @@ -63,7 +63,6 @@ viztracer = "^0.15.3"
[tool.poetry.extras]
duckdb = ["dbt-duckdb", "duckcli"]
sqlite = ["dbt-sqlite"]
postgres = ["dbt-postgres"]
workbench = [
"streamlit",
"streamlit-ace",
Expand Down
13 changes: 1 addition & 12 deletions src/dbt_osmosis/core/osmosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from dbt.adapters.factory import Adapter, get_adapter_class_by_name
from dbt.clients import jinja # monkey-patched for perf
from dbt.config.runtime import RuntimeConfig
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.manifest import ManifestNode, MaybeNonSource, MaybeParsedSource, NodeType
from dbt.contracts.graph.parsed import ColumnInfo
Expand Down Expand Up @@ -192,13 +191,7 @@ def __init__(
class DbtAdapterCompilationResult:
"""Interface for compilation results, this keeps us 1 layer removed from dbt interfaces which may change"""

def __init__(
self,
raw_sql: str,
compiled_sql: str,
node: ManifestNode,
injected_sql: Optional[str] = None,
) -> None:
def __init__(self, raw_sql: str, compiled_sql: str, node: ManifestNode, injected_sql: Optional[str] = None) -> None:
self.raw_sql = raw_sql
self.compiled_sql = compiled_sql
self.node = node
Expand Down Expand Up @@ -338,10 +331,6 @@ def _with_conn() -> T:

return _with_conn

def generate_runtime_model_context(self, node: ManifestNode):
"""Wraps dbt context provider"""
return generate_runtime_model_context(node, self.config, self.dbt)

@property
def project_name(self) -> str:
"""dbt project name"""
Expand Down
180 changes: 133 additions & 47 deletions src/dbt_osmosis/dbt_templater/templater.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,37 @@
"""Defines the dbt_osmosis templater."""
import logging

import os.path
import logging
import threading
from pathlib import Path
from typing import Optional

from dbt.clients import jinja
from dbt.exceptions import CompilationException as DbtCompilationException
from dbt.version import get_installed_version
from dbt.exceptions import (
CompilationException as DbtCompilationException,
)
from jinja2 import Environment
from jinja2_simple_tags import StandaloneTag
from sqlfluff.core.config import FluffConfig
from sqlfluff.core.errors import SQLTemplaterError
from sqlfluff.core.errors import SQLTemplaterError, SQLFluffSkipFile
from sqlfluff.core.templaters.base import TemplatedFile, large_file_check
from sqlfluff.core.templaters.jinja import JinjaTemplater

from dbt_osmosis.core.osmosis import DbtProjectContainer
from dbt_osmosis.core.osmosis import DbtProject

templater_logger = logging.getLogger(__name__)

DBT_VERSION = get_installed_version()
DBT_VERSION_TUPLE = (int(DBT_VERSION.major), int(DBT_VERSION.minor))

if DBT_VERSION_TUPLE >= (1, 3):
COMPILED_SQL_ATTRIBUTE = "compiled_code"
RAW_SQL_ATTRIBUTE = "raw_code"
else:
COMPILED_SQL_ATTRIBUTE = "compiled_sql"
RAW_SQL_ATTRIBUTE = "raw_sql"

local = threading.local()


class OsmosisDbtTemplater(JinjaTemplater):
"""dbt templater for dbt-osmosis, based on sqlfluff-templater-dbt."""
Expand All @@ -26,25 +42,20 @@ class OsmosisDbtTemplater(JinjaTemplater):
name = "dbt"

def __init__(self, **kwargs):
self.dbt_project_container: DbtProjectContainer = kwargs.pop("dbt_project_container")
self.dbt_project_container = kwargs.pop("dbt_project_container")
super().__init__(**kwargs)

def config_pairs(self): # pragma: no cover
"""Returns info about the given templater for output by the cli."""
return [("templater", self.name), ("dbt", DBT_VERSION.to_version_string())]

@large_file_check
def process(
self,
*,
in_str: str,
fname: Optional[str] = None,
config: Optional[FluffConfig] = None,
**kwargs,
):
def process(self, *, fname: str, in_str=None, config, **kwargs):
"""Compile a dbt model and return the compiled SQL."""
try:
return self._unsafe_process(os.path.abspath(fname) if fname else None, in_str, config)
return self._unsafe_process(
os.path.abspath(fname) if in_str is None else None, in_str, config
)
except DbtCompilationException as e:
if e.node:
return None, [
Expand All @@ -58,64 +69,133 @@ def process(
except SQLTemplaterError as e: # pragma: no cover
return None, [e]

def _unsafe_process(self, fname: Optional[str], in_str: str, config: FluffConfig = None):
# Get project
osmosis_dbt_project = self.dbt_project_container.get_project_by_root_dir(
# from .sqlfluff templater project_dir
def _find_node(self, project: DbtProject, fname: str):
expected_node_path = os.path.relpath(fname, start=os.path.abspath(project.args.project_dir))
node = project.get_node_by_path(expected_node_path)
if node:
return node
skip_reason = self._find_skip_reason(project, expected_node_path)
if skip_reason:
raise SQLFluffSkipFile(f"Skipped file {fname} because it is {skip_reason}")
raise SQLFluffSkipFile(f"File {fname} was not found in dbt project") # pragma: no cover

@staticmethod
def _find_skip_reason(project: DbtProject, expected_node_path: str) -> Optional[str]:
"""Return string reason if model okay to skip, otherwise None."""
# Scan macros.
for macro in project.dbt.macros.values():
if macro.original_file_path == expected_node_path:
return "a macro"

# Scan disabled nodes.
for nodes in project.dbt.disabled.values():
for node in nodes:
if node.original_file_path == expected_node_path:
return "disabled"
return None

@staticmethod
def from_string(*args, **kwargs):
"""Replaces (via monkeypatch) the jinja2.Environment function."""
globals = kwargs.get("globals")
if globals and hasattr(local, "target_sql"):
model = globals.get("model")
if model:
# Is it processing the node we're interested in?
if isinstance(local.target_sql, Path):
the_one = str(local.target_sql) == model.get("original_file_path")
else:
the_one = local.target_sql == args[1]
if the_one:
# Yes. Capture the important arguments and create
# a make_template() function.
env = args[0]
globals = args[2] if len(args) >= 3 else kwargs["globals"]

def make_template(in_str):
env.add_extension(SnapshotExtension)
return env.from_string(in_str, globals=globals)

local.make_template = make_template

return old_from_string(*args, **kwargs)

def _unsafe_process(self, fname, in_str, config=None):
osmosis_dbt_project = self.dbt_project_container().get_project_by_root_dir(
config.get_section((self.templater_selector, self.name, "project_dir"))
)
local.make_template = None
try:
if fname:
node = self._find_node(osmosis_dbt_project, fname)
local.target_sql = Path(
os.path.relpath(fname, start=osmosis_dbt_project.args.project_dir)
)
compiled_node = osmosis_dbt_project.compile_node(node)
else:
local.target_sql = in_str
# TRICKY: Use __wrapped__ to bypass the cache. We *must*
# recompile each time, because that's how we get the
# make_template() function.
compiled_node = osmosis_dbt_project.compile_sql.__wrapped__(
osmosis_dbt_project, in_str
)
except Exception as err:
raise SQLFluffSkipFile( # pragma: no cover
f"Skipped file {fname} because dbt raised a fatal "
f"exception during compilation: {err!s}"
) from err
finally:
local.target_sql = None

if compiled_node.injected_sql:
# If injected SQL is present, it contains a better picture
# of what will actually hit the database (e.g. with tests).
# However it's not always present.
compiled_sql = compiled_node.injected_sql
else:
compiled_sql = compiled_node.compiled_sql

raw_sql = compiled_node.raw_sql

# Use path if valid, prioritize it as the in_str
source_dbt_sql = str(in_str)
fpath = Path(fname)
if fpath.exists():
source_dbt_sql = fpath.read_text()
in_str = str(source_dbt_sql or in_str)

# Generate node
mock_node = osmosis_dbt_project.get_server_node(in_str, fname)
resp = osmosis_dbt_project.compile_node(mock_node)

# Generate context
ctx = osmosis_dbt_project.generate_runtime_model_context(resp.node)
env = jinja.get_environment(resp.node)
env.add_extension(SnapshotExtension)
compiled_sql = resp.compiled_sql
make_template = lambda _in_str: env.from_string(_in_str, globals=ctx)

# Need compiled
if not compiled_sql: # pragma: no cover
raise SQLTemplaterError(
"dbt templater compilation failed silently, check your "
"configuration by running `dbt compile` directly."
)

# Whitespace
if fname:
with open(fname) as source_dbt_model:
source_dbt_sql = source_dbt_model.read()
else:
source_dbt_sql = in_str

if not source_dbt_sql.rstrip().endswith("-%}"):
n_trailing_newlines = len(source_dbt_sql) - len(source_dbt_sql.rstrip("\n"))
else:
# Source file ends with right whitespace stripping, so there's
# no need to preserve/restore trailing newlines.
n_trailing_newlines = 0

# LOG
templater_logger.debug(
" Trailing newline count in source dbt model: %r",
n_trailing_newlines,
)
templater_logger.debug(" Raw SQL before compile: %r", source_dbt_sql)
templater_logger.debug(" Node raw SQL: %r", in_str)
templater_logger.debug(" Node raw SQL: %r", raw_sql)
templater_logger.debug(" Node compiled SQL: %r", compiled_sql)

# SLICE
# Adjust for dbt Jinja removing trailing newlines. For more details on
# this, see the similar code in sqlfluff-templater.dbt.
compiled_node.raw_sql = source_dbt_sql
compiled_sql = compiled_sql + "\n" * n_trailing_newlines
raw_sliced, sliced_file, templated_sql = self.slice_file(
raw_str=source_dbt_sql,
templated_str=compiled_sql + "\n" * n_trailing_newlines,
source_dbt_sql,
compiled_sql,
config=config,
make_template=make_template,
make_template=local.make_template,
append_to_templated="\n" if n_trailing_newlines else "",
)

return (
TemplatedFile(
source_str=source_dbt_sql,
Expand All @@ -129,6 +209,12 @@ def _unsafe_process(self, fname: Optional[str], in_str: str, config: FluffConfig
)


# Monkeypatch Environment.from_string(). OsmosisDbtTemplater uses this to
# intercept Jinja compilation and capture a template trace.
old_from_string = Environment.from_string
Environment.from_string = OsmosisDbtTemplater.from_string


class SnapshotExtension(StandaloneTag):
"""Dummy "snapshot" tags so raw dbt templates will parse.
Expand Down

0 comments on commit 828e4ab

Please sign in to comment.