examples/airline_demo/airline_demo_tests/test_types.py
(The first 131 lines of the file are unchanged and collapsed in the changeset view.)

     try:
             Bucket=intermediate_storage.object_store.bucket, Key=success_key
         )
     except botocore.exceptions.ClientError:
         raise Exception("Couldn't find object at {success_key}".format(success_key=success_key))


 def test_spark_dataframe_output_csv():
     spark = SparkSession.builder.getOrCreate()
-    num_df = (
-        spark.read.format("csv")
-        .options(header="true", inferSchema="true")
-        .load(file_relative_path(__file__, "num.csv"))
-    )
-
-    assert num_df.collect() == [Row(num1=1, num2=2)]
-
-    @solid
-    def emit(_):
+
+    @solid(
+        input_defs=[InputDefinition("num_df", DataFrame)],
+        output_defs=[OutputDefinition(DataFrame)],
+        required_resource_keys={"pyspark"},
+    )
+    def emit(_, num_df):
+        assert num_df.collect() == [Row(num1=1, num2=2)]
         return num_df

-    @solid(input_defs=[InputDefinition("df", DataFrame)], output_defs=[OutputDefinition(DataFrame)])
+    @solid(
+        input_defs=[InputDefinition("df", DataFrame)],
+        output_defs=[OutputDefinition(DataFrame)],
+        required_resource_keys={"pyspark"},
+    )
     def passthrough_df(_context, df):
         return df

-    @pipeline
+    @pipeline(mode_defs=[ModeDefinition(resource_defs={"pyspark": pyspark_resource})])
     def passthrough():
         passthrough_df(emit())

     with seven.TemporaryDirectory() as tempdir:
         file_name = os.path.join(tempdir, "output.csv")
         result = execute_pipeline(
             passthrough,
             run_config={
                 "solids": {
-                    "passthrough_df": {
-                        "outputs": [{"result": {"csv": {"path": file_name, "header": True}}}]
-                    }
+                    "emit": {
+                        "inputs": {
+                            "num_df": {
+                                "csv": {
+                                    "path": file_relative_path(__file__, "num.csv"),
+                                    "header": True,
+                                    "inferSchema": True,
+                                }
+                            }
+                        }
+                    },
+                    "passthrough_df": {
+                        "outputs": [{"result": {"csv": {"path": file_name, "header": True}}}],
+                    },
                 },
             },
         )

-        from_file_df = (
-            spark.read.format("csv").options(header="true", inferSchema="true").load(file_name)
-        )
+        from_file_df = spark.read.format("csv").options(header="true").load(file_name)

         assert (
             result.result_for_solid("passthrough_df").output_value().collect()
             == from_file_df.collect()
         )
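Context for reviewers: the updated test stops calling spark.read directly in the test body and instead loads num.csv through the dagster_pyspark DataFrame type's CSV input loader (the "inputs" block in run_config), with the solids running in a mode that supplies pyspark_resource and declaring required_resource_keys={"pyspark"}. Below is a minimal, self-contained sketch of that pattern. It is an illustration, not code from this changeset: the names emit_df, csv_roundtrip, and the __main__ harness are invented, it assumes the legacy @solid/@pipeline API plus dagster_pyspark's DataFrame and pyspark_resource, and it assumes a num.csv file sits next to the script and that pyspark/Java are available.

import os
import tempfile

from dagster import (
    InputDefinition,
    ModeDefinition,
    OutputDefinition,
    execute_pipeline,
    file_relative_path,
    pipeline,
    solid,
)
from dagster_pyspark import DataFrame, pyspark_resource


@solid(
    input_defs=[InputDefinition("num_df", DataFrame)],
    output_defs=[OutputDefinition(DataFrame)],
    # The pyspark resource is required so the DataFrame type's CSV
    # loader/materializer has a SparkSession to work with.
    required_resource_keys={"pyspark"},
)
def emit_df(_, num_df):
    # num_df arrives already loaded as a pyspark DataFrame via the
    # "csv" input config below; just pass it through.
    return num_df


@pipeline(mode_defs=[ModeDefinition(resource_defs={"pyspark": pyspark_resource})])
def csv_roundtrip():
    emit_df()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        out_path = os.path.join(tempdir, "output.csv")
        result = execute_pipeline(
            csv_roundtrip,
            run_config={
                "solids": {
                    "emit_df": {
                        # Load the input through the DataFrame type's CSV loader.
                        "inputs": {
                            "num_df": {
                                "csv": {
                                    # Assumed to exist next to this script.
                                    "path": file_relative_path(__file__, "num.csv"),
                                    "header": True,
                                    "inferSchema": True,
                                }
                            }
                        },
                        # Materialize the output back to CSV, as the changeset
                        # does for passthrough_df.
                        "outputs": [{"result": {"csv": {"path": out_path, "header": True}}}],
                    }
                }
            },
        )
        assert result.success
        assert result.result_for_solid("emit_df").output_value().collect()

The point of the pattern is that loading and writing go through the pipeline's own config and resources rather than ad hoc spark.read calls in the test body, so the CSV loader, the CSV materializer, and the pyspark resource wiring are all exercised by the run itself.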