
Commit d5f5d24

Add to_arrow to get a pyarrow.Table from query results. (googleapis#8609)
* Add `to_arrow` to get a `pyarrow.Table` from query results. An Arrow `Table` supports a richer set of types than a pandas `DataFrame`, and is the basis of many data analysis systems. It can be used in conjunction with pandas through the `Table.to_pandas()` method or the pandas extension types provided by the `fletcher` package.
* Exclude pyarrow 0.14.0 due to bad manylinux wheels.
1 parent 5f9c090 commit d5f5d24
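
In short, query results gain a `to_arrow()` method alongside the existing `to_dataframe()`. A minimal usage sketch (assuming a configured `bigquery.Client`; the trivial query is illustrative only, not part of this commit):

    from google.cloud import bigquery

    client = bigquery.Client()
    query_job = client.query("SELECT 1 AS x")

    arrow_table = query_job.to_arrow()  # new in this commit
    df = arrow_table.to_pandas()        # optional hand-off to pandas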

Showing 10 changed files with 753 additions and 33 deletions.

bigquery/google/cloud/bigquery/_helpers.py

Lines changed: 9 additions & 6 deletions
@@ -197,6 +197,14 @@ def _field_to_index_mapping(schema):
     return {f.name: i for i, f in enumerate(schema)}
 
 
+def _field_from_json(resource, field):
+    converter = _CELLDATA_FROM_JSON.get(field.field_type, lambda value, _: value)
+    if field.mode == "REPEATED":
+        return [converter(item["v"], field) for item in resource]
+    else:
+        return converter(resource, field)
+
+
 def _row_tuple_from_json(row, schema):
     """Convert JSON row data to row with appropriate types.
 
@@ -214,12 +222,7 @@ def _row_tuple_from_json(row, schema):
     """
     row_data = []
     for field, cell in zip(schema, row["f"]):
-        converter = _CELLDATA_FROM_JSON[field.field_type]
-        if field.mode == "REPEATED":
-            row_data.append([converter(item["v"], field) for item in cell["v"]])
-        else:
-            row_data.append(converter(cell["v"], field))
-
+        row_data.append(_field_from_json(cell["v"], field))
     return tuple(row_data)
 
 
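The new `_field_from_json` helper centralizes the per-cell conversion that `_row_tuple_from_json` previously inlined. A standalone sketch of the JSON shapes it handles (the converter map below is a hypothetical stand-in for the library's private `_CELLDATA_FROM_JSON` registry, named here only for illustration):

    # tabledata.list cells are strings; repeated cells are lists of {"v": ...}.
    converters = {"INTEGER": lambda value, field: int(value)}

    def field_from_json(resource, field_type, mode):
        # Unknown types fall through unchanged, as in the new helper.
        converter = converters.get(field_type, lambda value, field: value)
        if mode == "REPEATED":
            return [converter(item["v"], None) for item in resource]
        return converter(resource, None)

    print(field_from_json("42", "INTEGER", "NULLABLE"))                     # 42
    print(field_from_json([{"v": "1"}, {"v": "2"}], "INTEGER", "REPEATED"))  # [1, 2]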

bigquery/google/cloud/bigquery/_pandas_helpers.py

Lines changed: 51 additions & 9 deletions
@@ -14,7 +14,6 @@
 
 """Shared helper functions for connecting BigQuery and pandas."""
 
-import collections
 import concurrent.futures
 import warnings
 
@@ -115,7 +114,7 @@ def bq_to_arrow_data_type(field):
     """
     if field.mode is not None and field.mode.upper() == "REPEATED":
         inner_type = bq_to_arrow_data_type(
-            schema.SchemaField(field.name, field.field_type)
+            schema.SchemaField(field.name, field.field_type, fields=field.fields)
         )
         if inner_type:
             return pyarrow.list_(inner_type)
@@ -144,6 +143,21 @@ def bq_to_arrow_field(bq_field):
     return None
 
 
+def bq_to_arrow_schema(bq_schema):
+    """Return the Arrow schema corresponding to a given BigQuery schema.
+
+    Returns None if any Arrow type cannot be determined.
+    """
+    arrow_fields = []
+    for bq_field in bq_schema:
+        arrow_field = bq_to_arrow_field(bq_field)
+        if arrow_field is None:
+            # Auto-detect the schema if there is an unknown field type.
+            return None
+        arrow_fields.append(arrow_field)
+    return pyarrow.schema(arrow_fields)
+
+
 def bq_to_arrow_array(series, bq_field):
     arrow_type = bq_to_arrow_data_type(bq_field)
     if bq_field.mode.upper() == "REPEATED":
@@ -210,13 +224,41 @@ def dataframe_to_parquet(dataframe, bq_schema, filepath):
     pyarrow.parquet.write_table(arrow_table, filepath)
 
 
+def _tabledata_list_page_to_arrow(page, column_names, arrow_types):
+    # Iterate over the page to force the API request to get the page data.
+    try:
+        next(iter(page))
+    except StopIteration:
+        pass
+
+    arrays = []
+    for column_index, arrow_type in enumerate(arrow_types):
+        arrays.append(pyarrow.array(page._columns[column_index], type=arrow_type))
+
+    return pyarrow.RecordBatch.from_arrays(arrays, column_names)
+
+
+def download_arrow_tabledata_list(pages, schema):
+    """Use tabledata.list to construct an iterable of RecordBatches."""
+    column_names = bq_to_arrow_schema(schema) or [field.name for field in schema]
+    arrow_types = [bq_to_arrow_data_type(field) for field in schema]
+
+    for page in pages:
+        yield _tabledata_list_page_to_arrow(page, column_names, arrow_types)
+
+
 def _tabledata_list_page_to_dataframe(page, column_names, dtypes):
-    columns = collections.defaultdict(list)
-    for row in page:
-        for column in column_names:
-            columns[column].append(row[column])
-    for column in dtypes:
-        columns[column] = pandas.Series(columns[column], dtype=dtypes[column])
+    # Iterate over the page to force the API request to get the page data.
+    try:
+        next(iter(page))
+    except StopIteration:
+        pass
+
+    columns = {}
+    for column_index, column_name in enumerate(column_names):
+        dtype = dtypes.get(column_name)
+        columns[column_name] = pandas.Series(page._columns[column_index], dtype=dtype)
+
     return pandas.DataFrame(columns, columns=column_names)
 
 
@@ -350,7 +392,7 @@ def download_dataframe_bqstorage(
             continue
 
         # Return any remaining values after the workers finished.
-        while not worker_queue.empty():
+        while not worker_queue.empty():  # pragma: NO COVER
             try:
                 # Include a timeout because even though the queue is
                 # non-empty, it doesn't guarantee that a subsequent call to
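A minimal sketch of the schema conversion added above, assuming pyarrow is installed. Note that `_pandas_helpers` is an internal module, so this import path is for illustration only:

    from google.cloud.bigquery import schema
    from google.cloud.bigquery import _pandas_helpers

    bq_schema = [
        schema.SchemaField("name", "STRING"),
        schema.SchemaField("splits", "FLOAT", mode="REPEATED"),
    ]

    # Returns a pyarrow.Schema, or None if any field type is unknown so
    # that callers can fall back to schema auto-detection.
    arrow_schema = _pandas_helpers.bq_to_arrow_schema(bq_schema)
    print(arrow_schema)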

bigquery/google/cloud/bigquery/job.py

Lines changed: 38 additions & 0 deletions
@@ -2896,6 +2896,44 @@ def result(self, timeout=None, page_size=None, retry=DEFAULT_RETRY):
             rows._preserve_order = _contains_order_by(self.query)
         return rows
 
+    def to_arrow(self, progress_bar_type=None):
+        """[Beta] Create a :class:`pyarrow.Table` by loading all pages of a
+        table or query.
+
+        Args:
+            progress_bar_type (Optional[str]):
+                If set, use the `tqdm <https://tqdm.github.io/>`_ library to
+                display a progress bar while the data downloads. Install the
+                ``tqdm`` package to use this feature.
+
+                Possible values of ``progress_bar_type`` include:
+
+                ``None``
+                    No progress bar.
+                ``'tqdm'``
+                    Use the :func:`tqdm.tqdm` function to print a progress bar
+                    to :data:`sys.stderr`.
+                ``'tqdm_notebook'``
+                    Use the :func:`tqdm.tqdm_notebook` function to display a
+                    progress bar as a Jupyter notebook widget.
+                ``'tqdm_gui'``
+                    Use the :func:`tqdm.tqdm_gui` function to display a
+                    progress bar as a graphical dialog box.
+
+        Returns:
+            pyarrow.Table
+                A :class:`pyarrow.Table` populated with row data and column
+                headers from the query results. The column headers are derived
+                from the destination table's schema.
+
+        Raises:
+            ValueError:
+                If the :mod:`pyarrow` library cannot be imported.
+
+        .. versionadded:: 1.17.0
+        """
+        return self.result().to_arrow(progress_bar_type=progress_bar_type)
+
     def to_dataframe(self, bqstorage_client=None, dtypes=None, progress_bar_type=None):
         """Return a pandas DataFrame from a QueryJob
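A short usage sketch for the new `QueryJob.to_arrow` with a progress bar. It assumes the `tqdm` package is installed; the table name is a well-known public dataset used here only as a placeholder:

    from google.cloud import bigquery

    client = bigquery.Client()
    query_job = client.query(
        "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_current` LIMIT 100"
    )

    # Drains all result pages into one pyarrow.Table, printing a tqdm
    # progress bar to stderr while the pages download.
    table = query_job.to_arrow(progress_bar_type="tqdm")
    print(table.num_rows, table.num_columns)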

bigquery/google/cloud/bigquery/table.py

Lines changed: 113 additions & 0 deletions
@@ -33,6 +33,11 @@
 except ImportError:  # pragma: NO COVER
     pandas = None
 
+try:
+    import pyarrow
+except ImportError:  # pragma: NO COVER
+    pyarrow = None
+
 try:
     import tqdm
 except ImportError:  # pragma: NO COVER
@@ -58,6 +63,10 @@
     "The pandas library is not installed, please install "
     "pandas to use the to_dataframe() function."
 )
+_NO_PYARROW_ERROR = (
+    "The pyarrow library is not installed, please install "
+    "pyarrow to use the to_arrow() function."
+)
 _NO_TQDM_ERROR = (
     "A progress bar was requested, but there was an error loading the tqdm "
     "library. Please install tqdm to use the progress bar functionality."
@@ -1394,6 +1403,72 @@ def _get_progress_bar(self, progress_bar_type):
             warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3)
         return None
 
+    def _to_arrow_iterable(self):
+        """Create an iterable of arrow RecordBatches, to process the table as a stream."""
+        for record_batch in _pandas_helpers.download_arrow_tabledata_list(
+            iter(self.pages), self.schema
+        ):
+            yield record_batch
+
+    def to_arrow(self, progress_bar_type=None):
+        """[Beta] Create a :class:`pyarrow.Table` by loading all pages of a
+        table or query.
+
+        Args:
+            progress_bar_type (Optional[str]):
+                If set, use the `tqdm <https://tqdm.github.io/>`_ library to
+                display a progress bar while the data downloads. Install the
+                ``tqdm`` package to use this feature.
+
+                Possible values of ``progress_bar_type`` include:
+
+                ``None``
+                    No progress bar.
+                ``'tqdm'``
+                    Use the :func:`tqdm.tqdm` function to print a progress bar
+                    to :data:`sys.stderr`.
+                ``'tqdm_notebook'``
+                    Use the :func:`tqdm.tqdm_notebook` function to display a
+                    progress bar as a Jupyter notebook widget.
+                ``'tqdm_gui'``
+                    Use the :func:`tqdm.tqdm_gui` function to display a
+                    progress bar as a graphical dialog box.
+
+        Returns:
+            pyarrow.Table
+                A :class:`pyarrow.Table` populated with row data and column
+                headers from the query results. The column headers are derived
+                from the destination table's schema.
+
+        Raises:
+            ValueError:
+                If the :mod:`pyarrow` library cannot be imported.
+
+        .. versionadded:: 1.17.0
+        """
+        if pyarrow is None:
+            raise ValueError(_NO_PYARROW_ERROR)
+
+        progress_bar = self._get_progress_bar(progress_bar_type)
+
+        record_batches = []
+        for record_batch in self._to_arrow_iterable():
+            record_batches.append(record_batch)
+
+            if progress_bar is not None:
+                # In some cases, the number of total rows is not populated
+                # until the first page of rows is fetched. Update the
+                # progress bar's total to keep an accurate count.
+                progress_bar.total = progress_bar.total or self.total_rows
+                progress_bar.update(record_batch.num_rows)
+
+        if progress_bar is not None:
+            # Indicate that the download has finished.
+            progress_bar.close()
+
+        arrow_schema = _pandas_helpers.bq_to_arrow_schema(self._schema)
+        return pyarrow.Table.from_batches(record_batches, schema=arrow_schema)
+
     def _to_dataframe_iterable(self, bqstorage_client=None, dtypes=None):
         """Create an iterable of pandas DataFrames, to process the table as a stream.
 
@@ -1538,6 +1613,21 @@ class _EmptyRowIterator(object):
     pages = ()
     total_rows = 0
 
+    def to_arrow(self, progress_bar_type=None):
+        """[Beta] Create an empty :class:`pyarrow.Table`.
+
+        Args:
+            progress_bar_type (Optional[str]):
+                Ignored. Added for compatibility with RowIterator.
+
+        Returns:
+            pyarrow.Table:
+                An empty :class:`pyarrow.Table`.
+        """
+        if pyarrow is None:
+            raise ValueError(_NO_PYARROW_ERROR)
+        return pyarrow.Table.from_arrays(())
+
     def to_dataframe(self, bqstorage_client=None, dtypes=None, progress_bar_type=None):
         """Create an empty dataframe.
 
@@ -1734,6 +1824,25 @@ def _item_to_row(iterator, resource):
     )
 
 
+def _tabledata_list_page_columns(schema, response):
+    """Make a generator of all the columns in a page from tabledata.list.
+
+    This enables creating a :class:`pandas.DataFrame` and other
+    column-oriented data structures such as :class:`pyarrow.RecordBatch`.
+    """
+    columns = []
+    rows = response.get("rows", [])
+
+    def get_column_data(field_index, field):
+        for row in rows:
+            yield _helpers._field_from_json(row["f"][field_index]["v"], field)
+
+    for field_index, field in enumerate(schema):
+        columns.append(get_column_data(field_index, field))
+
+    return columns
+
+
 # pylint: disable=unused-argument
 def _rows_page_start(iterator, page, response):
     """Grab total rows when :class:`~google.cloud.iterator.Page` starts.
@@ -1747,6 +1856,10 @@ def _rows_page_start(iterator, page, response):
     :type response: dict
     :param response: The JSON API response for a page of rows in a table.
     """
+    # Make a (lazy) copy of the page in column-oriented format for use in data
+    # science packages.
+    page._columns = _tabledata_list_page_columns(iterator._schema, response)
+
     total_rows = response.get("totalRows")
     if total_rows is not None:
         total_rows = int(total_rows)
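Because `to_arrow` lives on `RowIterator`, it also works for plain table reads, not just query results. A sketch assuming a placeholder table ID:

    from google.cloud import bigquery

    client = bigquery.Client()

    # list_rows returns a RowIterator; to_arrow drains its pages into
    # RecordBatches and stitches them into a single pyarrow.Table.
    rows = client.list_rows("your-project.your_dataset.your_table")  # placeholder ID
    arrow_table = rows.to_arrow()
    print(arrow_table.schema)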

bigquery/samples/query_to_arrow.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def main(client):
+    # [START bigquery_query_to_arrow]
+    # TODO(developer): Import the client library.
+    # from google.cloud import bigquery
+
+    # TODO(developer): Construct a BigQuery client object.
+    # client = bigquery.Client()
+
+    sql = """
+    WITH races AS (
+      SELECT "800M" AS race,
+        [STRUCT("Rudisha" as name, [23.4, 26.3, 26.4, 26.1] as splits),
+         STRUCT("Makhloufi" as name, [24.5, 25.4, 26.6, 26.1] as splits),
+         STRUCT("Murphy" as name, [23.9, 26.0, 27.0, 26.0] as splits),
+         STRUCT("Bosse" as name, [23.6, 26.2, 26.5, 27.1] as splits),
+         STRUCT("Rotich" as name, [24.7, 25.6, 26.9, 26.4] as splits),
+         STRUCT("Lewandowski" as name, [25.0, 25.7, 26.3, 27.2] as splits),
+         STRUCT("Kipketer" as name, [23.2, 26.1, 27.3, 29.4] as splits),
+         STRUCT("Berian" as name, [23.7, 26.1, 27.0, 29.3] as splits)]
+        AS participants)
+    SELECT
+      race,
+      participant
+    FROM races r
+    CROSS JOIN UNNEST(r.participants) as participant;
+    """
+    query_job = client.query(sql)
+    arrow_table = query_job.to_arrow()
+
+    print(
+        "Downloaded {} rows, {} columns.".format(
+            arrow_table.num_rows, arrow_table.num_columns
+        )
+    )
+    print("\nSchema:\n{}".format(repr(arrow_table.schema)))
+    # [END bigquery_query_to_arrow]
+    return arrow_table
+
+
+if __name__ == "__main__":
+    from google.cloud import bigquery
+
+    main(bigquery.Client())
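
As the commit message notes, the resulting Arrow table can then be handed to pandas. A possible continuation of the sample (not part of the committed file); the nested `participant` STRUCT column with its REPEATED `splits` survives the conversion:

    df = arrow_table.to_pandas()
    print(df.head())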
