Skip to content

Commit

Permalink
Improve SDK UX with task creation (cvat-ai#5502)
Browse files Browse the repository at this point in the history
Extracted from cvat-ai#5083

- Added a default arg for task data uploading
- Added an option to wait for the data processing in task data uploading
- Moved data splitting by requests for TUS closer to the point of use
  • Loading branch information
zhiltsov-max authored Jan 2, 2023
1 parent 1d00e51 commit c808471
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 82 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## \[2.4.0] - Unreleased
### Added
- \[SDK\] An arg to wait for data processing in the task data uploading function
(<https://github.com/opencv/cvat/pull/5502>)
- Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>, <https://github.com/opencv/cvat/pull/5525>)
- \[SDK\] Configuration setting to change the dataset cache directory
(<https://github.com/opencv/cvat/pull/5535>)
Expand All @@ -17,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- The Docker Compose files now use the Compose Specification version
of the format. This version is supported by Docker Compose 1.27.0+
(<https://github.com/opencv/cvat/pull/5524>).
- \[SDK\] The `resource_type` args now have the default value of `local` in task creation functions.
The corresponding arguments are keyword-only now.
(<https://github.com/opencv/cvat/pull/5502>)

### Deprecated
- TBD
Expand Down
2 changes: 1 addition & 1 deletion cvat-cli/src/cvat_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def tasks_create(
self,
name: str,
labels: List[Dict[str, str]],
resource_type: ResourceType,
resources: Sequence[str],
*,
resource_type: ResourceType = ResourceType.LOCAL,
annotation_path: str = "",
annotation_format: str = "CVAT XML 1.1",
status_check_period: int = 2,
Expand Down
69 changes: 43 additions & 26 deletions cvat-sdk/cvat_sdk/core/proxies/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,13 @@ class Task(

def upload_data(
self,
resource_type: ResourceType,
resources: Sequence[StrPath],
*,
resource_type: ResourceType = ResourceType.LOCAL,
pbar: Optional[ProgressReporter] = None,
params: Optional[Dict[str, Any]] = None,
wait_for_completion: bool = True,
status_check_period: Optional[int] = None,
) -> None:
"""
Add local, remote, or shared files to an existing task.
Expand Down Expand Up @@ -121,6 +123,37 @@ def upload_data(
url, list(map(Path, resources)), pbar=pbar, **data
)

if wait_for_completion:
if status_check_period is None:
status_check_period = self._client.config.status_check_period

self._client.logger.info("Awaiting for task %s creation...", self.id)
while True:
sleep(status_check_period)
(status, response) = self.api.retrieve_status(self.id)

self._client.logger.info(
"Task %s creation status: %s (message=%s)",
self.id,
status.state.value,
status.message,
)

if (
status.state.value
== models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]
):
break
elif (
status.state.value
== models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]
):
raise exceptions.ApiException(
status=status.state.value, reason=status.message, http_resp=response
)

self.fetch()

def import_annotations(
self,
format_name: str,
Expand Down Expand Up @@ -296,9 +329,9 @@ class TasksRepo(
def create_from_data(
self,
spec: models.ITaskWriteRequest,
resource_type: ResourceType,
resources: Sequence[str],
*,
resource_type: ResourceType = ResourceType.LOCAL,
data_params: Optional[Dict[str, Any]] = None,
annotation_path: str = "",
annotation_format: str = "CVAT XML 1.1",
Expand All @@ -313,9 +346,6 @@ def create_from_data(
Returns: id of the created task
"""
if status_check_period is None:
status_check_period = self._client.config.status_check_period

if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
raise exceptions.ApiValueError(
"Can't set labels to a task inside a project. "
Expand All @@ -326,27 +356,14 @@ def create_from_data(
task = self.create(spec=spec)
self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name)

task.upload_data(resource_type, resources, pbar=pbar, params=data_params)

self._client.logger.info("Awaiting for task %s creation...", task.id)
status: models.RqStatus = None
while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
sleep(status_check_period)
(status, response) = self.api.retrieve_status(task.id)

self._client.logger.info(
"Task %s creation status=%s, message=%s",
task.id,
status.state.value,
status.message,
)

if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
raise exceptions.ApiException(
status=status.state.value, reason=status.message, http_resp=response
)

status = status.state.value
task.upload_data(
resource_type=resource_type,
resources=resources,
pbar=pbar,
params=data_params,
wait_for_completion=True,
status_check_period=status_check_period,
)

if annotation_path:
task.import_annotations(annotation_format, annotation_path, pbar=pbar)
Expand Down
103 changes: 53 additions & 50 deletions cvat-sdk/cvat_sdk/core/uploading.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations

import os
from contextlib import ExitStack, closing
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

Expand Down Expand Up @@ -206,40 +205,6 @@ def _wait_for_completion(
positive_statuses=positive_statuses,
)

def _split_files_by_requests(
    self, filenames: List[Path]
) -> Tuple[List[Tuple[List[Path], int]], List[Path], int]:
    """
    Partition files into upload batches according to MAX_REQUEST_SIZE.

    Every path is resolved and its on-disk size is read. Files larger than
    MAX_REQUEST_SIZE are set aside on their own; the remaining files are
    packed, in input order, into groups whose summed size does not exceed
    MAX_REQUEST_SIZE.

    Returns a 3-tuple of:
    - the groups, each as (paths in the group, summed size of the group),
    - the oversized files, as a mapping of resolved path -> size,
    - the summed size of all files.
    """
    small: Dict[str, int] = {}
    oversized: Dict[str, int] = {}

    # partition the files around the request size limit
    for entry in filenames:
        resolved = entry.resolve()
        size = resolved.stat().st_size
        target = oversized if size > MAX_REQUEST_SIZE else small
        target[resolved] = size

    total_size = sum(small.values()) + sum(oversized.values())

    # pack the small files greedily, closing a batch once the next file
    # would push it over the limit
    batches: List[Tuple[List[str], int]] = []
    pending: List[str] = []
    pending_size: int = 0
    for resolved, size in small.items():
        if pending_size + size > MAX_REQUEST_SIZE:
            batches.append((pending, pending_size))
            pending, pending_size = [], 0

        pending.append(resolved)
        pending_size += size

    if pending:
        batches.append((pending, pending_size))

    return batches, oversized, total_size

@staticmethod
def _make_tus_uploader(api_client: ApiClient, url: str, **kwargs):
# Add headers required by CVAT server
Expand Down Expand Up @@ -353,6 +318,10 @@ def upload_file_and_wait(


class DataUploader(Uploader):
def __init__(self, client: Client, *, max_request_size: int = MAX_REQUEST_SIZE) -> None:
    """
    Create an uploader bound to the given client.

    client: forwarded unchanged to the base class initializer.
    max_request_size: size limit (compared against file sizes in bytes) that
        _split_files_by_requests() uses to decide which files are grouped
        together and which are handled separately. Defaults to the
        module-level MAX_REQUEST_SIZE.
    """
    super().__init__(client)
    # Kept on the instance so the limit can be overridden per uploader
    self.max_request_size = max_request_size

def upload_files(
self,
url: str,
Expand All @@ -369,22 +338,21 @@ def upload_files(
self._tus_start_upload(url)

for group, group_size in bulk_file_groups:
with ExitStack() as es:
files = {}
for i, filename in enumerate(group):
files[f"client_files[{i}]"] = (
os.fspath(filename),
es.enter_context(closing(open(filename, "rb"))).read(),
)
response = self._client.api_client.rest_client.POST(
url,
post_params=dict(**kwargs, **files),
headers={
"Content-Type": "multipart/form-data",
"Upload-Multiple": "",
**self._client.api_client.get_common_headers(),
},
files = {}
for i, filename in enumerate(group):
files[f"client_files[{i}]"] = (
os.fspath(filename),
filename.read_bytes(),
)
response = self._client.api_client.rest_client.POST(
url,
post_params=dict(**kwargs, **files),
headers={
"Content-Type": "multipart/form-data",
"Upload-Multiple": "",
**self._client.api_client.get_common_headers(),
},
)
expect_status(200, response)

if pbar is not None:
Expand All @@ -401,3 +369,38 @@ def upload_files(
)

self._tus_finish_upload(url, fields=kwargs)

def _split_files_by_requests(
    self, filenames: List[Path]
) -> Tuple[List[Tuple[List[Path], int]], Dict[Path, int], int]:
    """
    Partition files into upload batches according to self.max_request_size.

    Every path is resolved and its on-disk size (in bytes) is read. Files
    larger than the limit are set aside to be handled one by one; the
    remaining files are packed, in input order, into groups whose summed
    size does not exceed the limit. A path passed more than once is
    counted only once.

    Returns a 3-tuple of:
    - the groups, each as (resolved paths in the group, group size),
    - the oversized files, as a mapping of resolved path -> size
      (iterating it yields the paths),
    - the total size of all files.
    """
    max_request_size = self.max_request_size

    bulk_files: Dict[Path, int] = {}
    separate_files: Dict[Path, int] = {}

    # partition the files around the request size limit
    for filename in filenames:
        filename = filename.resolve()
        file_size = filename.stat().st_size
        if max_request_size < file_size:
            separate_files[filename] = file_size
        else:
            bulk_files[filename] = file_size

    total_size = sum(bulk_files.values()) + sum(separate_files.values())

    # group small files by requests: close the current group as soon as
    # the next file would push it over the limit
    bulk_file_groups: List[Tuple[List[Path], int]] = []
    current_group: List[Path] = []
    current_group_size = 0
    for filename, file_size in bulk_files.items():
        if max_request_size < current_group_size + file_size:
            bulk_file_groups.append((current_group, current_group_size))
            current_group = []
            current_group_size = 0

        current_group.append(filename)
        current_group_size += file_size

    if current_group:
        bulk_file_groups.append((current_group, current_group_size))

    return bulk_file_groups, separate_files, total_size
8 changes: 4 additions & 4 deletions tests/python/sdk/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def setup(
models.PatchedLabelRequest(name="car"),
],
),
ResourceType.LOCAL,
list(map(os.fspath, image_paths)),
resource_type=ResourceType.LOCAL,
resources=list(map(os.fspath, image_paths)),
data_params={"chunk_size": 3},
)

Expand Down Expand Up @@ -274,8 +274,8 @@ def setup(
project_id=self.project.id,
subset=subset,
),
ResourceType.LOCAL,
image_paths,
resource_type=ResourceType.LOCAL,
resources=image_paths,
data_params={"image_quality": 70},
)
for subset, image_paths in zip(subsets, image_paths_per_task)
Expand Down
33 changes: 32 additions & 1 deletion tests/python/sdk/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def fxt_new_task(self, fxt_image_file: Path):
"name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}],
},
resource_type=ResourceType.LOCAL,
resources=[fxt_image_file],
data_params={"image_quality": 80},
)
Expand Down Expand Up @@ -202,6 +201,38 @@ def test_can_create_task_with_git_repo(self, fxt_image_file: Path):
assert response_json["format"] == "CVAT for images 1.1"
assert response_json["lfs"] is False

def test_can_upload_data_to_empty_task(self):
    """Check that local files can be uploaded to a task created without data."""
    file_count = 7

    pbar_out = io.StringIO()
    pbar = make_pbar(file=pbar_out)

    task = self.client.tasks.create(
        {
            "name": "test task",
            "labels": [{"name": "car"}],
        }
    )

    data_params = {
        "image_quality": 75,
    }

    # The generated images are in-memory objects; write them to disk so
    # they can be passed as local file paths
    task_files = []
    for image in generate_image_files(file_count):
        image_path = self.tmp_path / image.name
        image_path.write_bytes(image.getvalue())
        task_files.append(image_path)

    task.upload_data(
        resources=task_files,
        resource_type=ResourceType.LOCAL,
        params=data_params,
        pbar=pbar,
    )

    assert task.size == file_count
    # The last progress bar update must report completion
    assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
    # Nothing must leak to stdout
    assert self.stdout.getvalue() == ""

def test_can_retrieve_task(self, fxt_new_task: Task):
task_id = fxt_new_task.id

Expand Down

0 comments on commit c808471

Please sign in to comment.