Differential D9013 Diff 42219 examples/hacker_news/hacker_news_tests/test_resources/test_snowflake_io_manager.py
Changeset View
Changeset View
Standalone View
Standalone View
examples/hacker_news/hacker_news_tests/test_resources/test_snowflake_io_manager.py
import os | import os | ||||
import uuid | import uuid | ||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
from dagster import InputContext, OutputContext | from dagster import build_init_resource_context, build_input_context, build_output_context | ||||
from hacker_news.resources.snowflake_io_manager import ( # pylint: disable=E0401 | from hacker_news.resources.snowflake_io_manager import ( # pylint: disable=E0401 | ||||
SnowflakeIOManager, | |||||
connect_snowflake, | connect_snowflake, | ||||
snowflake_io_manager, | |||||
) | ) | ||||
from pandas import DataFrame | from pandas import DataFrame | ||||
def generate_snowflake_config(): | def generate_snowflake_config(): | ||||
return { | return { | ||||
"account": os.getenv("SNOWFLAKE_ACCOUNT"), | "account": os.getenv("SNOWFLAKE_ACCOUNT"), | ||||
"user": os.getenv("SNOWFLAKE_USER"), | "user": os.getenv("SNOWFLAKE_USER"), | ||||
"password": os.getenv("SNOWFLAKE_PASSWORD"), | "password": os.getenv("SNOWFLAKE_PASSWORD"), | ||||
"database": "DEMO_DB", | "database": "DEMO_DB", | ||||
"warehouse": "TINY_WAREHOUSE", | "warehouse": "TINY_WAREHOUSE", | ||||
} | } | ||||
@contextmanager | @contextmanager | ||||
def temporary_snowflake_table(contents: DataFrame): | def temporary_snowflake_table(contents: DataFrame): | ||||
snowflake_config = generate_snowflake_config() | snowflake_config = generate_snowflake_config() | ||||
table_name = "a" + str(uuid.uuid4()).replace("-", "_") | table_name = "a" + str(uuid.uuid4()).replace("-", "_") | ||||
with connect_snowflake(snowflake_config) as con: | with connect_snowflake(snowflake_config) as con: | ||||
contents.to_sql(name=table_name, con=con, index=False, schema="public") | contents.to_sql(name=table_name, con=con, index=False, schema="public") | ||||
try: | try: | ||||
yield table_name | yield table_name | ||||
finally: | finally: | ||||
with connect_snowflake(snowflake_config) as conn: | with connect_snowflake(snowflake_config) as conn: | ||||
conn.execute(f"drop table public.{table_name}") | conn.execute(f"drop table public.{table_name}") | ||||
def test_handle_output_then_load_input(): | def test_handle_output_then_load_input(): | ||||
snowflake_config = generate_snowflake_config() | snowflake_config = generate_snowflake_config() | ||||
snowflake_manager = SnowflakeIOManager() | snowflake_manager = snowflake_io_manager(build_init_resource_context(config=snowflake_config)) | ||||
contents1 = DataFrame([{"col1": "a", "col2": 1}]) # just to get the types right | contents1 = DataFrame([{"col1": "a", "col2": 1}]) # just to get the types right | ||||
contents2 = DataFrame([{"col1": "b", "col2": 2}]) # contents we will insert | contents2 = DataFrame([{"col1": "b", "col2": 2}]) # contents we will insert | ||||
with temporary_snowflake_table(contents1) as temp_table_name: | with temporary_snowflake_table(contents1) as temp_table_name: | ||||
metadata = { | metadata = { | ||||
"table": f"public.{temp_table_name}", | "table": f"public.{temp_table_name}", | ||||
} | } | ||||
output_context = OutputContext( | output_context = build_output_context(metadata=metadata, resource_config=snowflake_config) | ||||
step_key="a", | |||||
name="result", | |||||
pipeline_name="fake_pipeline", | |||||
metadata=metadata, | |||||
resource_config=snowflake_config, | |||||
) | |||||
list(snowflake_manager.handle_output(output_context, contents2)) # exhaust the iterator | list(snowflake_manager.handle_output(output_context, contents2)) # exhaust the iterator | ||||
input_context = InputContext( | input_context = build_input_context( | ||||
upstream_output=output_context, resource_config=snowflake_config | upstream_output=output_context, resource_config=snowflake_config | ||||
) | ) | ||||
input_value = snowflake_manager.load_input(input_context) | input_value = snowflake_manager.load_input(input_context) | ||||
assert input_value.equals(contents2), f"{input_value}\n\n{contents2}" | assert input_value.equals(contents2), f"{input_value}\n\n{contents2}" |