Martin Blanchard pushed to branch mablanch/79-cas-downloader at BuildGrid / buildgrid
Commits:
- 9879a212 by Martin Blanchard at 2018-09-24T17:28:15Z
- e5d4575a by Martin Blanchard at 2018-09-24T17:28:20Z
- f6904a74 by Martin Blanchard at 2018-09-24T17:28:20Z
- 5d557dda by Martin Blanchard at 2018-09-24T17:28:20Z
- 8dde6848 by Martin Blanchard at 2018-09-24T17:28:20Z
- 18739973 by Martin Blanchard at 2018-09-24T17:28:20Z
- c28060e8 by Martin Blanchard at 2018-09-24T17:28:20Z
- 645c5a20 by Martin Blanchard at 2018-09-24T17:28:20Z
- ea0a05dc by Martin Blanchard at 2018-09-24T17:28:20Z
- 34995a8c by Martin Blanchard at 2018-09-24T17:28:20Z
11 changed files:
- buildgrid/_app/bots/buildbox.py
- buildgrid/_app/bots/temp_directory.py
- buildgrid/_app/commands/cmd_cas.py
- buildgrid/_app/commands/cmd_execute.py
- buildgrid/client/cas.py
- buildgrid/server/cas/service.py
- buildgrid/server/cas/storage/remote.py
- buildgrid/utils.py
- docs/source/reference_cli.rst
- tests/cas/test_client.py
- tests/utils/cas.py
Changes:
| ... | ... | @@ -19,32 +19,34 @@ import tempfile | 
| 19 | 19 |  | 
| 20 | 20 |  from google.protobuf import any_pb2
 | 
| 21 | 21 |  | 
| 22 | -from buildgrid.settings import HASH_LENGTH
 | |
| 23 | -from buildgrid.client.cas import upload
 | |
| 24 | -from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
 | |
| 25 | -from buildgrid._protos.google.bytestream import bytestream_pb2_grpc
 | |
| 22 | +from buildgrid.client.cas import download, upload
 | |
| 26 | 23 |  from buildgrid._exceptions import BotError
 | 
| 27 | -from buildgrid.utils import read_file, write_file, parse_to_pb2_from_fetch
 | |
| 24 | +from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
 | |
| 25 | +from buildgrid.settings import HASH_LENGTH
 | |
| 26 | +from buildgrid.utils import read_file, write_file
 | |
| 28 | 27 |  | 
| 29 | 28 |  | 
| 30 | 29 |  def work_buildbox(context, lease):
 | 
| 31 | 30 |      """Executes a lease for a build action, using buildbox.
 | 
| 32 | 31 |      """
 | 
| 33 | 32 |  | 
| 34 | -    stub_bytestream = bytestream_pb2_grpc.ByteStreamStub(context.cas_channel)
 | |
| 35 | 33 |      local_cas_directory = context.local_cas
 | 
| 34 | +    # instance_name = context.parent
 | |
| 36 | 35 |      logger = context.logger
 | 
| 37 | 36 |  | 
| 38 | 37 |      action_digest = remote_execution_pb2.Digest()
 | 
| 39 | 38 |      lease.payload.Unpack(action_digest)
 | 
| 40 | 39 |  | 
| 41 | -    action = parse_to_pb2_from_fetch(remote_execution_pb2.Action(),
 | |
| 42 | -                                     stub_bytestream, action_digest)
 | |
| 40 | +    with download(context.cas_channel) as downloader:
 | |
| 41 | +        action = downloader.get_message(action_digest,
 | |
| 42 | +                                        remote_execution_pb2.Action())
 | |
| 43 | 43 |  | 
| 44 | -    command = parse_to_pb2_from_fetch(remote_execution_pb2.Command(),
 | |
| 45 | -                                      stub_bytestream, action.command_digest)
 | |
| 44 | +        assert action.command_digest.hash
 | |
| 46 | 45 |  | 
| 47 | -    environment = dict()
 | |
| 46 | +        command = downloader.get_message(action.command_digest,
 | |
| 47 | +                                         remote_execution_pb2.Command())
 | |
| 48 | + | |
| 49 | +    environment = {}
 | |
| 48 | 50 |      for variable in command.environment_variables:
 | 
| 49 | 51 |          if variable.name not in ['PWD']:
 | 
| 50 | 52 |              environment[variable.name] = variable.value
 | 
| ... | ... | @@ -116,10 +118,11 @@ def work_buildbox(context, lease): | 
| 116 | 118 |  | 
| 117 | 119 |              # TODO: Have BuildBox helping us creating the Tree instance here
 | 
| 118 | 120 |              # See https://gitlab.com/BuildStream/buildbox/issues/7 for details
 | 
| 119 | -            output_tree = _cas_tree_maker(stub_bytestream, output_digest)
 | |
| 121 | +            with download(context.cas_channel) as downloader:
 | |
| 122 | +                output_tree = _cas_tree_maker(downloader, output_digest)
 | |
| 120 | 123 |  | 
| 121 | -            with upload(context.cas_channel) as cas:
 | |
| 122 | -                output_tree_digest = cas.put_message(output_tree)
 | |
| 124 | +            with upload(context.cas_channel) as uploader:
 | |
| 125 | +                output_tree_digest = uploader.put_message(output_tree)
 | |
| 123 | 126 |  | 
| 124 | 127 |              output_directory = remote_execution_pb2.OutputDirectory()
 | 
| 125 | 128 |              output_directory.tree_digest.CopyFrom(output_tree_digest)
 | 
| ... | ... | @@ -135,24 +138,28 @@ def work_buildbox(context, lease): | 
| 135 | 138 |      return lease
 | 
| 136 | 139 |  | 
| 137 | 140 |  | 
| 138 | -def _cas_tree_maker(stub_bytestream, directory_digest):
 | |
| 141 | +def _cas_tree_maker(cas, directory_digest):
 | |
| 139 | 142 |      # Generates and stores a Tree for a given Directory. This is very inefficient
 | 
| 140 | 143 |      # and only temporary. See https://gitlab.com/BuildStream/buildbox/issues/7.
 | 
| 141 | 144 |      output_tree = remote_execution_pb2.Tree()
 | 
| 142 | 145 |  | 
| 143 | -    def list_directories(parent_directory):
 | |
| 144 | -        directory_list = list()
 | |
| 146 | +    def __cas_tree_maker(cas, parent_directory):
 | |
| 147 | +        digests, directories = [], []
 | |
| 145 | 148 |          for directory_node in parent_directory.directories:
 | 
| 146 | -            directory = parse_to_pb2_from_fetch(remote_execution_pb2.Directory(),
 | |
| 147 | -                                                stub_bytestream, directory_node.digest)
 | |
| 148 | -            directory_list.extend(list_directories(directory))
 | |
| 149 | -            directory_list.append(directory)
 | |
| 149 | +            directories.append(remote_execution_pb2.Directory())
 | |
| 150 | +            digests.append(directory_node.digest)
 | |
| 151 | + | |
| 152 | +        cas.get_messages(digests, directories)
 | |
| 153 | + | |
| 154 | +        for directory in directories[:]:
 | |
| 155 | +            directories.extend(__cas_tree_maker(cas, directory))
 | |
| 156 | + | |
| 157 | +        return directories
 | |
| 150 | 158 |  | 
| 151 | -        return directory_list
 | |
| 159 | +    root_directory = cas.get_message(directory_digest,
 | |
| 160 | +                                     remote_execution_pb2.Directory())
 | |
| 152 | 161 |  | 
| 153 | -    root_directory = parse_to_pb2_from_fetch(remote_execution_pb2.Directory(),
 | |
| 154 | -                                             stub_bytestream, directory_digest)
 | |
| 155 | -    output_tree.children.extend(list_directories(root_directory))
 | |
| 162 | +    output_tree.children.extend(__cas_tree_maker(cas, root_directory))
 | |
| 156 | 163 |      output_tree.root.CopyFrom(root_directory)
 | 
| 157 | 164 |  | 
| 158 | 165 |      return output_tree | 
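The new `_cas_tree_maker` batches each level of `Directory` fetches through `Downloader.get_messages()` instead of issuing one ByteStream read per directory. A minimal, self-contained sketch of the same breadth-first idea (the `fetch_tree` name and channel handling are illustrative, not part of the change):

    from buildgrid.client.cas import download
    from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2

    def fetch_tree(channel, root_digest):
        # Builds a Tree message for root_digest, one batched request per level.
        with download(channel) as downloader:
            root = downloader.get_message(root_digest,
                                          remote_execution_pb2.Directory())
            children, frontier = [], [root]
            while frontier:
                digests = [node.digest
                           for directory in frontier
                           for node in directory.directories]
                if not digests:
                    break
                level = [remote_execution_pb2.Directory() for _ in digests]
                downloader.get_messages(digests, level)  # one request per level
                children.extend(level)
                frontier = level
            tree = remote_execution_pb2.Tree()
            tree.root.CopyFrom(root)
            tree.children.extend(children)
            return tree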
| ... | ... | @@ -19,10 +19,8 @@ import tempfile | 
| 19 | 19 |  | 
| 20 | 20 |  from google.protobuf import any_pb2
 | 
| 21 | 21 |  | 
| 22 | -from buildgrid.client.cas import upload
 | |
| 22 | +from buildgrid.client.cas import download, upload
 | |
| 23 | 23 |  from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
 | 
| 24 | -from buildgrid._protos.google.bytestream import bytestream_pb2_grpc
 | |
| 25 | -from buildgrid.utils import write_fetch_directory, parse_to_pb2_from_fetch
 | |
| 26 | 24 |  from buildgrid.utils import output_file_maker, output_directory_maker
 | 
| 27 | 25 |  | 
| 28 | 26 |  | 
| ... | ... | @@ -30,29 +28,30 @@ def work_temp_directory(context, lease): | 
| 30 | 28 |      """Executes a lease for a build action, using host tools.
 | 
| 31 | 29 |      """
 | 
| 32 | 30 |  | 
| 33 | -    stub_bytestream = bytestream_pb2_grpc.ByteStreamStub(context.cas_channel)
 | |
| 34 | 31 |      instance_name = context.parent
 | 
| 35 | 32 |      logger = context.logger
 | 
| 36 | 33 |  | 
| 37 | 34 |      action_digest = remote_execution_pb2.Digest()
 | 
| 38 | 35 |      lease.payload.Unpack(action_digest)
 | 
| 39 | 36 |  | 
| 40 | -    action = parse_to_pb2_from_fetch(remote_execution_pb2.Action(),
 | |
| 41 | -                                     stub_bytestream, action_digest, instance_name)
 | |
| 42 | - | |
| 43 | 37 |      with tempfile.TemporaryDirectory() as temp_directory:
 | 
| 44 | -        command = parse_to_pb2_from_fetch(remote_execution_pb2.Command(),
 | |
| 45 | -                                          stub_bytestream, action.command_digest, instance_name)
 | |
| 38 | +        with download(context.cas_channel, instance=instance_name) as downloader:
 | |
| 39 | +            action = downloader.get_message(action_digest,
 | |
| 40 | +                                            remote_execution_pb2.Action())
 | |
| 41 | + | |
| 42 | +            assert action.command_digest.hash
 | |
| 43 | + | |
| 44 | +            command = downloader.get_message(action.command_digest,
 | |
| 45 | +                                             remote_execution_pb2.Command())
 | |
| 46 | 46 |  | 
| 47 | -        write_fetch_directory(temp_directory, stub_bytestream,
 | |
| 48 | -                              action.input_root_digest, instance_name)
 | |
| 47 | +            downloader.download_directory(action.input_root_digest, temp_directory)
 | |
| 49 | 48 |  | 
| 50 | 49 |          environment = os.environ.copy()
 | 
| 51 | 50 |          for variable in command.environment_variables:
 | 
| 52 | 51 |              if variable.name not in ['PATH', 'PWD']:
 | 
| 53 | 52 |                  environment[variable.name] = variable.value
 | 
| 54 | 53 |  | 
| 55 | -        command_line = list()
 | |
| 54 | +        command_line = []
 | |
| 56 | 55 |          for argument in command.arguments:
 | 
| 57 | 56 |              command_line.append(argument.strip())
 | 
| 58 | 57 |  | 
| ... | ... | @@ -28,7 +28,7 @@ from urllib.parse import urlparse | 
| 28 | 28 |  import click
 | 
| 29 | 29 |  import grpc
 | 
| 30 | 30 |  | 
| 31 | -from buildgrid.client.cas import upload
 | |
| 31 | +from buildgrid.client.cas import download, upload
 | |
| 32 | 32 |  from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
 | 
| 33 | 33 |  from buildgrid.utils import merkle_tree_maker
 | 
| 34 | 34 |  | 
| ... | ... | @@ -81,18 +81,18 @@ def upload_dummy(context): | 
| 81 | 81 |          click.echo("Error: Failed pushing empty message.", err=True)
 | 
| 82 | 82 |  | 
| 83 | 83 |  | 
| 84 | -@cli.command('upload-files', short_help="Upload files to the CAS server.")
 | |
| 85 | -@click.argument('files', nargs=-1, type=click.Path(exists=True, dir_okay=False), required=True)
 | |
| 84 | +@cli.command('upload-file', short_help="Upload files to the CAS server.")
 | |
| 85 | +@click.argument('file_path', nargs=-1, type=click.Path(exists=True, dir_okay=False), required=True)
 | |
| 86 | 86 |  @pass_context
 | 
| 87 | -def upload_files(context, files):
 | |
| 87 | +def upload_file(context, file_path):
 | |
| 88 | 88 |      sent_digests, files_map = [], {}
 | 
| 89 | 89 |      with upload(context.channel, instance=context.instance_name) as uploader:
 | 
| 90 | -        for file_path in files:
 | |
| 91 | -            context.logger.debug("Queueing {}".format(file_path))
 | |
| 90 | +        for path in file_path:
 | |
| 91 | +            context.logger.debug("Queueing {}".format(path))
 | |
| 92 | 92 |  | 
| 93 | -            file_digest = uploader.upload_file(file_path, queue=True)
 | |
| 93 | +            file_digest = uploader.upload_file(path, queue=True)
 | |
| 94 | 94 |  | 
| 95 | -            files_map[file_digest.hash] = file_path
 | |
| 95 | +            files_map[file_digest.hash] = path
 | |
| 96 | 96 |              sent_digests.append(file_digest)
 | 
| 97 | 97 |  | 
| 98 | 98 |      for file_digest in sent_digests:
 | 
| ... | ... | @@ -107,12 +107,12 @@ def upload_files(context, files): | 
| 107 | 107 |  | 
| 108 | 108 |  | 
| 109 | 109 |  @cli.command('upload-dir', short_help="Upload a directory to the CAS server.")
 | 
| 110 | -@click.argument('directory', nargs=1, type=click.Path(exists=True, file_okay=False), required=True)
 | |
| 110 | +@click.argument('directory-path', nargs=1, type=click.Path(exists=True, file_okay=False), required=True)
 | |
| 111 | 111 |  @pass_context
 | 
| 112 | -def upload_dir(context, directory):
 | |
| 112 | +def upload_directory(context, directory_path):
 | |
| 113 | 113 |      sent_digests, nodes_map = [], {}
 | 
| 114 | 114 |      with upload(context.channel, instance=context.instance_name) as uploader:
 | 
| 115 | -        for node, blob, path in merkle_tree_maker(directory):
 | |
| 115 | +        for node, blob, path in merkle_tree_maker(directory_path):
 | |
| 116 | 116 |              context.logger.debug("Queueing {}".format(path))
 | 
| 117 | 117 |  | 
| 118 | 118 |              node_digest = uploader.put_blob(blob, digest=node.digest, queue=True)
 | 
| ... | ... | @@ -123,9 +123,49 @@ def upload_dir(context, directory): | 
| 123 | 123 |      for node_digest in sent_digests:
 | 
| 124 | 124 |          node_path = nodes_map[node_digest.hash]
 | 
| 125 | 125 |          if os.path.isabs(node_path):
 | 
| 126 | -            node_path = os.path.relpath(node_path, start=directory)
 | |
| 126 | +            node_path = os.path.relpath(node_path, start=directory_path)
 | |
| 127 | 127 |          if node_digest.ByteSize():
 | 
| 128 | 128 |              click.echo('Success: Pushed "{}" with digest "{}/{}"'
 | 
| 129 | 129 |                         .format(node_path, node_digest.hash, node_digest.size_bytes))
 | 
| 130 | 130 |          else:
 | 
| 131 | 131 |              click.echo('Error: Failed to push "{}"'.format(node_path), err=True)
 | 
| 132 | + | |
| 133 | + | |
| 134 | +def __create_digest(digest_string):
 | |
| 135 | +    digest_hash, digest_size = digest_string.split('/')
 | |
| 136 | + | |
| 137 | +    digest = remote_execution_pb2.Digest()
 | |
| 138 | +    digest.hash = digest_hash
 | |
| 139 | +    digest.size_bytes = int(digest_size)
 | |
| 140 | + | |
| 141 | +    return digest
 | |
| 142 | + | |
| 143 | + | |
| 144 | +@cli.command('download-file', short_help="Download a file from the CAS server.")
 | |
| 145 | +@click.argument('digest-string', nargs=1, type=click.STRING, required=True)
 | |
| 146 | +@click.argument('file-path', nargs=1, type=click.Path(exists=False), required=True)
 | |
| 147 | +@pass_context
 | |
| 148 | +def download_file(context, digest_string, file_path):
 | |
| 149 | +    if os.path.exists(file_path):
 | |
| 150 | +        click.echo('Error: Invalid value for "file-path": ' +
 | |
| 151 | +                   'Path "{}" already exists.'.format(file_path), err=True)
 | |
| 152 | + | |
| 153 | +    digest = __create_digest(digest_string)
 | |
| 154 | + | |
| 155 | +    with download(context.channel, instance=context.instance_name) as downloader:
 | |
| 156 | +        downloader.download_file(digest, file_path)
 | |
| 157 | + | |
| 158 | + | |
| 159 | +@cli.command('download-dir', short_help="Download a directory from the CAS server.")
 | |
| 160 | +@click.argument('digest-string', nargs=1, type=click.STRING, required=True)
 | |
| 161 | +@click.argument('directory-path', nargs=1, type=click.Path(exists=False), required=True)
 | |
| 162 | +@pass_context
 | |
| 163 | +def download_directory(context, digest_string, directory_path):
 | |
| 164 | +    if os.path.exists(directory_path):
 | |
| 165 | +        click.echo('Error: Invalid value for "directory-path": ' +
 | |
| 166 | +                   'Path "{}" already exists.'.format(directory_path), err=True)
 | |
| 167 | + | |
| 168 | +    digest = __create_digest(digest_string)
 | |
| 169 | + | |
| 170 | +    with download(context.channel, instance=context.instance_name) as downloader:
 | |
| 171 | +        downloader.download_directory(digest, directory_path) | 
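Both download commands identify a blob with a `<hash>/<size_bytes>` string, as parsed by `__create_digest` above. A hypothetical value, using the well-known SHA-256 of the empty blob (size 0):

    # '<hash>/<size_bytes>':
    digest_string = 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0'

    digest_hash, digest_size = digest_string.split('/')
    assert len(digest_hash) == 64 and int(digest_size) == 0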
| ... | ... | @@ -20,7 +20,6 @@ Execute command | 
| 20 | 20 |  Request work to be executed and monitor status of jobs.
 | 
| 21 | 21 |  """
 | 
| 22 | 22 |  | 
| 23 | -import errno
 | |
| 24 | 23 |  import logging
 | 
| 25 | 24 |  import os
 | 
| 26 | 25 |  import stat
 | 
| ... | ... | @@ -30,10 +29,9 @@ from urllib.parse import urlparse | 
| 30 | 29 |  import click
 | 
| 31 | 30 |  import grpc
 | 
| 32 | 31 |  | 
| 33 | -from buildgrid.client.cas import upload
 | |
| 32 | +from buildgrid.client.cas import download, upload
 | |
| 34 | 33 |  from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
 | 
| 35 | -from buildgrid._protos.google.bytestream import bytestream_pb2_grpc
 | |
| 36 | -from buildgrid.utils import create_digest, write_fetch_blob
 | |
| 34 | +from buildgrid.utils import create_digest
 | |
| 37 | 35 |  | 
| 38 | 36 |  from ..cli import pass_context
 | 
| 39 | 37 |  | 
| ... | ... | @@ -154,8 +152,6 @@ def run_command(context, input_root, commands, output_file, output_directory): | 
| 154 | 152 |                                                    skip_cache_lookup=True)
 | 
| 155 | 153 |      response = stub.Execute(request)
 | 
| 156 | 154 |  | 
| 157 | -    stub = bytestream_pb2_grpc.ByteStreamStub(context.channel)
 | |
| 158 | - | |
| 159 | 155 |      stream = None
 | 
| 160 | 156 |      for stream in response:
 | 
| 161 | 157 |          context.logger.info(stream)
 | 
| ... | ... | @@ -163,21 +159,16 @@ def run_command(context, input_root, commands, output_file, output_directory): | 
| 163 | 159 |      execute_response = remote_execution_pb2.ExecuteResponse()
 | 
| 164 | 160 |      stream.response.Unpack(execute_response)
 | 
| 165 | 161 |  | 
| 166 | -    for output_file_response in execute_response.result.output_files:
 | |
| 167 | -        path = os.path.join(output_directory, output_file_response.path)
 | |
| 168 | - | |
| 169 | -        if not os.path.exists(os.path.dirname(path)):
 | |
| 162 | +    with download(context.channel, instance=context.instance_name) as downloader:
 | |
| 170 | 163 |  | 
| 171 | -            try:
 | |
| 172 | -                os.makedirs(os.path.dirname(path))
 | |
| 164 | +        for output_file_response in execute_response.result.output_files:
 | |
| 165 | +            path = os.path.join(output_directory, output_file_response.path)
 | |
| 173 | 166 |  | 
| 174 | -            except OSError as exc:
 | |
| 175 | -                if exc.errno != errno.EEXIST:
 | |
| 176 | -                    raise
 | |
| 167 | +            if not os.path.exists(os.path.dirname(path)):
 | |
| 168 | +                os.makedirs(os.path.dirname(path), exist_ok=True)
 | |
| 177 | 169 |  | 
| 178 | -        with open(path, 'wb+') as f:
 | |
| 179 | -            write_fetch_blob(f, stub, output_file_response.digest, context.instance_name)
 | |
| 170 | +            downloader.download_file(output_file_response.digest, path)
 | |
| 180 | 171 |  | 
| 181 | -        if output_file_response.path in output_executeables:
 | |
| 182 | -            st = os.stat(path)
 | |
| 183 | -            os.chmod(path, st.st_mode | stat.S_IXUSR) | |
| 172 | +            if output_file_response.path in output_executeables:
 | |
| 173 | +                st = os.stat(path)
 | |
| 174 | +                os.chmod(path, st.st_mode | stat.S_IXUSR) | 
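A condensed sketch of the new output-collection flow (the function name and the use of the REAPI `is_executable` flag are illustrative; the command itself tracks the executables requested on its command line):

    import os
    import stat

    from buildgrid.client.cas import download

    def collect_outputs(channel, instance_name, action_result, output_directory):
        with download(channel, instance=instance_name) as downloader:
            for output_file in action_result.output_files:
                path = os.path.join(output_directory, output_file.path)
                os.makedirs(os.path.dirname(path), exist_ok=True)

                # queue=False: the file must be on disk before os.stat() below,
                # and queued downloads are only written out on flush()/close().
                downloader.download_file(output_file.digest, path, queue=False)

                if output_file.is_executable:
                    st = os.stat(path)
                    os.chmod(path, st.st_mode | stat.S_IXUSR)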
| ... | ... | @@ -19,6 +19,7 @@ import os | 
| 19 | 19 |  | 
| 20 | 20 |  import grpc
 | 
| 21 | 21 |  | 
| 22 | +from buildgrid._exceptions import NotFoundError
 | |
| 22 | 23 |  from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
 | 
| 23 | 24 |  from buildgrid._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
 | 
| 24 | 25 |  from buildgrid._protos.google.rpc import code_pb2
 | 
| ... | ... | @@ -26,6 +27,16 @@ from buildgrid.settings import HASH | 
| 26 | 27 |  from buildgrid.utils import merkle_tree_maker
 | 
| 27 | 28 |  | 
| 28 | 29 |  | 
| 30 | +# Maximum size for a queueable file:
 | |
| 31 | +FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
 | |
| 32 | + | |
| 33 | +# Maximum size for a single gRPC request:
 | |
| 34 | +MAX_REQUEST_SIZE = 2 * 1024 * 1024
 | |
| 35 | + | |
| 36 | +# Maximum number of elements per gRPC request:
 | |
| 37 | +MAX_REQUEST_COUNT = 500
 | |
| 38 | + | |
| 39 | + | |
| 29 | 40 |  class _CallCache:
 | 
| 30 | 41 |      """Per remote grpc.StatusCode.UNIMPLEMENTED call cache."""
 | 
| 31 | 42 |      __calls = {}
 | 
| ... | ... | @@ -43,6 +54,401 @@ class _CallCache: | 
| 43 | 54 |          return name in cls.__calls[channel]
 | 
| 44 | 55 |  | 
| 45 | 56 |  | 
| 57 | +@contextmanager
 | |
| 58 | +def download(channel, instance=None, u_uid=None):
 | |
| 59 | +    """Context manager generator for the :class:`Downloader` class."""
 | |
| 60 | +    downloader = Downloader(channel, instance=instance)
 | |
| 61 | +    try:
 | |
| 62 | +        yield downloader
 | |
| 63 | +    finally:
 | |
| 64 | +        downloader.close()
 | |
| 65 | + | |
| 66 | + | |
| 67 | +class Downloader:
 | |
| 68 | +    """Remote CAS files, directories and messages download helper.
 | |
| 69 | + | |
| 70 | +    The :class:`Downloader` class comes with a generator factory function that
 | |
| 71 | +    can be used together with the `with` statement for context management::
 | |
| 72 | + | |
| 73 | +        from buildgrid.client.cas import download
 | |
| 74 | + | |
| 75 | +        with download(channel, instance='build') as downloader:
 | |
| 76 | +            downloader.get_message(message_digest, message)
 | |
| 77 | +    """
 | |
| 78 | + | |
| 79 | +    def __init__(self, channel, instance=None):
 | |
| 80 | +        """Initializes a new :class:`Downloader` instance.
 | |
| 81 | + | |
| 82 | +        Args:
 | |
| 83 | +            channel (grpc.Channel): A gRPC channel to the CAS endpoint.
 | |
| 84 | +            instance (str, optional): the targeted instance's name.
 | |
| 85 | +        """
 | |
| 86 | +        self.channel = channel
 | |
| 87 | + | |
| 88 | +        self.instance_name = instance
 | |
| 89 | + | |
| 90 | +        self.__bytestream_stub = bytestream_pb2_grpc.ByteStreamStub(self.channel)
 | |
| 91 | +        self.__cas_stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
 | |
| 92 | + | |
| 93 | +        self.__file_requests = {}
 | |
| 94 | +        self.__file_request_count = 0
 | |
| 95 | +        self.__file_request_size = 0
 | |
| 96 | +        self.__file_response_size = 0
 | |
| 97 | + | |
| 98 | +    # --- Public API ---
 | |
| 99 | + | |
| 100 | +    def get_blob(self, digest):
 | |
| 101 | +        """Retrieves a blob from the remote CAS server.
 | |
| 102 | + | |
| 103 | +        Args:
 | |
| 104 | +            digest (:obj:`Digest`): the blob's digest to fetch.
 | |
| 105 | + | |
| 106 | +        Returns:
 | |
| 107 | +            bytearray: the fetched blob data or None if not found.
 | |
| 108 | +        """
 | |
| 109 | +        try:
 | |
| 110 | +            blob = self._fetch_blob(digest)
 | |
| 111 | +        except NotFoundError:
 | |
| 112 | +            return None
 | |
| 113 | + | |
| 114 | +        return blob
 | |
| 115 | + | |
| 116 | +    def get_blobs(self, digests):
 | |
| 117 | +        """Retrieves a list of blobs from the remote CAS server.
 | |
| 118 | + | |
| 119 | +        Args:
 | |
| 120 | +            digests (list): list of :obj:`Digest`s for the blobs to fetch.
 | |
| 121 | + | |
| 122 | +        Returns:
 | |
| 123 | +            list: the fetched blob data list.
 | |
| 124 | +        """
 | |
| 125 | +        return self._fetch_blob_batch(digests)
 | |
| 126 | + | |
| 127 | +    def get_message(self, digest, message):
 | |
| 128 | +        """Retrieves a :obj:`Message` from the remote CAS server.
 | |
| 129 | + | |
| 130 | +        Args:
 | |
| 131 | +            digest (:obj:`Digest`): the message's digest to fetch.
 | |
| 132 | +            message (:obj:`Message`): an empty message to fill.
 | |
| 133 | + | |
| 134 | +        Returns:
 | |
| 135 | +            :obj:`Message`: the filled `message`, or an emptied one if not found.
 | |
| 136 | +        """
 | |
| 137 | +        try:
 | |
| 138 | +            message_blob = self._fetch_blob(digest)
 | |
| 139 | +        except NotFoundError:
 | |
| 140 | +            message_blob = None
 | |
| 141 | + | |
| 142 | +        if message_blob is not None:
 | |
| 143 | +            message.ParseFromString(message_blob)
 | |
| 144 | +        else:
 | |
| 145 | +            message.Clear()
 | |
| 146 | + | |
| 147 | +        return message
 | |
| 148 | + | |
| 149 | +    def get_messages(self, digests, messages):
 | |
| 150 | +        """Retrieves a list of :obj:`Message`s from the remote CAS server.
 | |
| 151 | + | |
| 152 | +        Note:
 | |
| 153 | +            The `digests` and `messages` list **must** contain the same number
 | |
| 154 | +            of elements.
 | |
| 155 | + | |
| 156 | +        Args:
 | |
| 157 | +            digests (list):  list of :obj:`Digest`s for the messages to fetch.
 | |
| 158 | +            messages (list): list of empty :obj:`Message`s to fill.
 | |
| 159 | + | |
| 160 | +        Returns:
 | |
| 161 | +            list: the fetched and filled message list.
 | |
| 162 | +        """
 | |
| 163 | +        assert len(digests) == len(messages)
 | |
| 164 | + | |
| 165 | +        message_blobs = self._fetch_blob_batch(digests)
 | |
| 166 | + | |
| 167 | +        assert len(message_blobs) == len(messages)
 | |
| 168 | + | |
| 169 | +        for message, message_blob in zip(messages, message_blobs):
 | |
| 170 | +            message.ParseFromString(message_blob)
 | |
| 171 | + | |
| 172 | +        return messages
 | |
| 173 | + | |
| 174 | +    def download_file(self, digest, file_path, queue=True):
 | |
| 175 | +        """Retrieves a file from the remote CAS server.
 | |
| 176 | + | |
| 177 | +        If queuing is allowed (`queue=True`), the download request **may** be
 | |
| 178 | +        deferred. An explicit call to :func:`~flush` can force the request to be
 | |
| 179 | +        sent immediately (along with the rest of the queued batch).
 | |
| 180 | + | |
| 181 | +        Args:
 | |
| 182 | +            digest (:obj:`Digest`): the file's digest to fetch.
 | |
| 183 | +            file_path (str): absolute or relative path to the local file to write.
 | |
| 184 | +            queue (bool, optional): whether or not the download request may be
 | |
| 185 | +                queued and submitted as part of a batch upload request. Defaults
 | |
| 186 | +                to True.
 | |
| 187 | + | |
| 188 | +        Raises:
 | |
| 189 | +            NotFoundError: if `digest` is not present in the remote CAS server.
 | |
| 190 | +            OSError: if `file_path` cannot be created or written.
 | |
| 191 | +        """
 | |
| 192 | +        if not os.path.isabs(file_path):
 | |
| 193 | +            file_path = os.path.abspath(file_path)
 | |
| 194 | + | |
| 195 | +        if not queue or digest.size_bytes > FILE_SIZE_THRESHOLD:
 | |
| 196 | +            self._fetch_file(digest, file_path)
 | |
| 197 | +        else:
 | |
| 198 | +            self._queue_file(digest, file_path)
 | |
| 199 | + | |
| 200 | +    def download_directory(self, digest, directory_path):
 | |
| 201 | +        """Retrieves a :obj:`Directory` from the remote CAS server.
 | |
| 202 | + | |
| 203 | +        Args:
 | |
| 204 | +            digest (:obj:`Digest`): the directory's digest to fetch.
 | |
| 205 | + | |
| 206 | +        Raises:
 | |
| 207 | +            NotFoundError: if `digest` is not present in the remote CAS server.
 | |
| 208 | +            FileExistsError: if `directory_path` already contains parts of the
 | |
| 209 | +                fetched directory's content.
 | |
| 210 | +        """
 | |
| 211 | +        if not os.path.isabs(directory_path):
 | |
| 212 | +            directory_path = os.path.abspath(directory_path)
 | |
| 213 | + | |
| 214 | +        # We want to start fresh here, the rest is very synchronous...
 | |
| 215 | +        self.flush()
 | |
| 216 | + | |
| 217 | +        self._fetch_directory(digest, directory_path)
 | |
| 218 | + | |
| 219 | +    def flush(self):
 | |
| 220 | +        """Ensures any queued request gets sent."""
 | |
| 221 | +        if self.__file_requests:
 | |
| 222 | +            self._fetch_file_batch(self.__file_requests)
 | |
| 223 | + | |
| 224 | +            self.__file_requests.clear()
 | |
| 225 | +            self.__file_request_count = 0
 | |
| 226 | +            self.__file_request_size = 0
 | |
| 227 | +            self.__file_response_size = 0
 | |
| 228 | + | |
| 229 | +    def close(self):
 | |
| 230 | +        """Closes the underlying connection stubs.
 | |
| 231 | + | |
| 232 | +        Note:
 | |
| 233 | +            This will always send pending requests before closing connections,
 | |
| 234 | +            if any.
 | |
| 235 | +        """
 | |
| 236 | +        self.flush()
 | |
| 237 | + | |
| 238 | +        self.__bytestream_stub = None
 | |
| 239 | +        self.__cas_stub = None
 | |
| 240 | + | |
| 241 | +    # --- Private API ---
 | |
| 242 | + | |
| 243 | +    def _fetch_blob(self, digest):
 | |
| 244 | +        """Fetches a blob using ByteStream.Read()"""
 | |
| 245 | +        read_blob = bytearray()
 | |
| 246 | + | |
| 247 | +        if self.instance_name is not None:
 | |
| 248 | +            resource_name = '/'.join([self.instance_name, 'blobs',
 | |
| 249 | +                                      digest.hash, str(digest.size_bytes)])
 | |
| 250 | +        else:
 | |
| 251 | +            resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
 | |
| 252 | + | |
| 253 | +        read_request = bytestream_pb2.ReadRequest()
 | |
| 254 | +        read_request.resource_name = resource_name
 | |
| 255 | +        read_request.read_offset = 0
 | |
| 256 | + | |
| 257 | +        try:
 | |
| 258 | +            # TODO: Handle connection loss/recovery
 | |
| 259 | +            for read_response in self.__bytestream_stub.Read(read_request):
 | |
| 260 | +                read_blob += read_response.data
 | |
| 261 | + | |
| 262 | +            assert len(read_blob) == digest.size_bytes
 | |
| 263 | + | |
| 264 | +        except grpc.RpcError as e:
 | |
| 265 | +            status_code = e.code()
 | |
| 266 | +            if status_code == grpc.StatusCode.NOT_FOUND:
 | |
| 267 | +                raise NotFoundError("Requested data does not exist on the remote.")
 | |
| 268 | + | |
| 269 | +            else:
 | |
| 270 | +                assert False
 | |
| 271 | + | |
| 272 | +        return read_blob
 | |
| 273 | + | |
| 274 | +    def _fetch_blob_batch(self, digests):
 | |
| 275 | +        """Fetches blobs using ContentAddressableStorage.BatchReadBlobs()"""
 | |
| 276 | +        batch_fetched = False
 | |
| 277 | +        read_blobs = []
 | |
| 278 | + | |
| 279 | +        # First, try BatchReadBlobs(), if not already known not being implemented:
 | |
| 280 | +        if not _CallCache.unimplemented(self.channel, 'BatchReadBlobs'):
 | |
| 281 | +            batch_request = remote_execution_pb2.BatchReadBlobsRequest()
 | |
| 282 | +            batch_request.digests.extend(digests)
 | |
| 283 | +            if self.instance_name is not None:
 | |
| 284 | +                batch_request.instance_name = self.instance_name
 | |
| 285 | + | |
| 286 | +            try:
 | |
| 287 | +                batch_response = self.__cas_stub.BatchReadBlobs(batch_request)
 | |
| 288 | +                for response in batch_response.responses:
 | |
| 289 | +                    assert response.digest in digests
 | |
| 290 | + | |
| 291 | +                    read_blobs.append(response.data)
 | |
| 292 | + | |
| 293 | +                    if response.status.code != code_pb2.OK:
 | |
| 294 | +                        assert False
 | |
| 295 | + | |
| 296 | +                batch_fetched = True
 | |
| 297 | + | |
| 298 | +            except grpc.RpcError as e:
 | |
| 299 | +                status_code = e.code()
 | |
| 300 | +                if status_code == grpc.StatusCode.UNIMPLEMENTED:
 | |
| 301 | +                    _CallCache.mark_unimplemented(self.channel, 'BatchReadBlobs')
 | |
| 302 | + | |
| 303 | +                elif status_code == grpc.StatusCode.INVALID_ARGUMENT:
 | |
| 304 | +                    read_blobs.clear()
 | |
| 305 | +                    batch_fetched = False
 | |
| 306 | + | |
| 307 | +                else:
 | |
| 308 | +                    assert False
 | |
| 309 | + | |
| 310 | +        # Fallback to Read() if no BatchReadBlobs():
 | |
| 311 | +        if not batch_fetched:
 | |
| 312 | +            for digest in digests:
 | |
| 313 | +                read_blobs.append(self._fetch_blob(digest))
 | |
| 314 | + | |
| 315 | +        return read_blobs
 | |
| 316 | + | |
| 317 | +    def _fetch_file(self, digest, file_path):
 | |
| 318 | +        """Fetches a file using ByteStream.Read()"""
 | |
| 319 | +        if self.instance_name is not None:
 | |
| 320 | +            resource_name = '/'.join([self.instance_name, 'blobs',
 | |
| 321 | +                                      digest.hash, str(digest.size_bytes)])
 | |
| 322 | +        else:
 | |
| 323 | +            resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
 | |
| 324 | + | |
| 325 | +        read_request = bytestream_pb2.ReadRequest()
 | |
| 326 | +        read_request.resource_name = resource_name
 | |
| 327 | +        read_request.read_offset = 0
 | |
| 328 | + | |
| 329 | +        os.makedirs(os.path.dirname(file_path), exist_ok=True)
 | |
| 330 | + | |
| 331 | +        with open(file_path, 'wb') as byte_file:
 | |
| 332 | +            # TODO: Handle connection loss/recovery
 | |
| 333 | +            for read_response in self.__bytestream_stub.Read(read_request):
 | |
| 334 | +                byte_file.write(read_response.data)
 | |
| 335 | + | |
| 336 | +            assert byte_file.tell() == digest.size_bytes
 | |
| 337 | + | |
| 338 | +    def _queue_file(self, digest, file_path):
 | |
| 339 | +        """Queues a file for later batch download"""
 | |
| 340 | +        if self.__file_request_size + digest.ByteSize() > MAX_REQUEST_SIZE:
 | |
| 341 | +            self.flush()
 | |
| 342 | +        elif self.__file_response_size + digest.size_bytes > MAX_REQUEST_SIZE:
 | |
| 343 | +            self.flush()
 | |
| 344 | +        elif self.__file_request_count >= MAX_REQUEST_COUNT:
 | |
| 345 | +            self.flush()
 | |
| 346 | + | |
| 347 | +        self.__file_requests[digest.hash] = (digest, file_path)
 | |
| 348 | +        self.__file_request_count += 1
 | |
| 349 | +        self.__file_request_size += digest.ByteSize()
 | |
| 350 | +        self.__file_response_size += digest.size_bytes
 | |
| 351 | + | |
| 352 | +    def _fetch_file_batch(self, batch):
 | |
| 353 | +        """Sends queued data using ContentAddressableStorage.BatchReadBlobs()"""
 | |
| 354 | +        batch_digests = [digest for digest, _ in batch.values()]
 | |
| 355 | +        batch_blobs = self._fetch_blob_batch(batch_digests)
 | |
| 356 | + | |
| 357 | +        for (_, file_path), file_blob in zip(batch.values(), batch_blobs):
 | |
| 358 | +            os.makedirs(os.path.dirname(file_path), exist_ok=True)
 | |
| 359 | + | |
| 360 | +            with open(file_path, 'wb') as byte_file:
 | |
| 361 | +                byte_file.write(file_blob)
 | |
| 362 | + | |
| 363 | +    def _fetch_directory(self, digest, directory_path):
 | |
| 364 | +        """Fetches a directory using ContentAddressableStorage.GetTree()"""
 | |
| 365 | +        # Better fail early if the local root path cannot be created:
 | |
| 366 | +        os.makedirs(directory_path, exist_ok=True)
 | |
| 367 | + | |
| 368 | +        directories = {}
 | |
| 369 | +        directory_fetched = False
 | |
| 370 | +        # First, try GetTree() if not known to be unimplemented yet:
 | |
| 371 | +        if not _CallCache.unimplemented(self.channel, 'GetTree'):
 | |
| 372 | +            tree_request = remote_execution_pb2.GetTreeRequest()
 | |
| 373 | +            tree_request.root_digest.CopyFrom(digest)
 | |
| 374 | +            tree_request.page_size = MAX_REQUEST_COUNT
 | |
| 375 | +            if self.instance_name is not None:
 | |
| 376 | +                tree_request.instance_name = self.instance_name
 | |
| 377 | + | |
| 378 | +            try:
 | |
| 379 | +                for tree_response in self.__cas_stub.GetTree(tree_request):
 | |
| 380 | +                    for directory in tree_response.directories:
 | |
| 381 | +                        directory_blob = directory.SerializeToString()
 | |
| 382 | +                        directory_hash = HASH(directory_blob).hexdigest()
 | |
| 383 | + | |
| 384 | +                        directories[directory_hash] = directory
 | |
| 385 | + | |
| 386 | +                assert digest.hash in directories
 | |
| 387 | + | |
| 388 | +                directory = directories[digest.hash]
 | |
| 389 | +                self._write_directory(directory, directory_path,
 | |
| 390 | +                                      directories=directories, root_barrier=directory_path)
 | |
| 391 | + | |
| 392 | +                directory_fetched = True
 | |
| 393 | + | |
| 394 | +            except grpc.RpcError as e:
 | |
| 395 | +                status_code = e.code()
 | |
| 396 | +                if status_code == grpc.StatusCode.UNIMPLEMENTED:
 | |
| 397 | +                    _CallCache.mark_unimplemented(self.channel, 'BatchUpdateBlobs')
 | |
| 398 | + | |
| 399 | +                elif status_code == grpc.StatusCode.NOT_FOUND:
 | |
| 400 | +                    raise NotFoundError("Requested directory does not exist on the remote.")
 | |
| 401 | + | |
| 402 | +                else:
 | |
| 403 | +                    assert False
 | |
| 404 | + | |
| 405 | +        # TODO: Try with BatchReadBlobs().
 | |
| 406 | + | |
| 407 | +        # Fallback to Read() if no GetTree():
 | |
| 408 | +        if not directory_fetched:
 | |
| 409 | +            directory = remote_execution_pb2.Directory()
 | |
| 410 | +            directory.ParseFromString(self._fetch_blob(digest))
 | |
| 411 | + | |
| 412 | +            self._write_directory(directory, directory_path,
 | |
| 413 | +                                  root_barrier=directory_path)
 | |
| 414 | + | |
| 415 | +    def _write_directory(self, root_directory, root_path, directories=None, root_barrier=None):
 | |
| 416 | +        """Generates a local directory structure"""
 | |
| 417 | +        for file_node in root_directory.files:
 | |
| 418 | +            file_path = os.path.join(root_path, file_node.name)
 | |
| 419 | + | |
| 420 | +            self._queue_file(file_node.digest, file_path)
 | |
| 421 | + | |
| 422 | +        for directory_node in root_directory.directories:
 | |
| 423 | +            directory_path = os.path.join(root_path, directory_node.name)
 | |
| 424 | +            if directories and directory_node.digest.hash in directories:
 | |
| 425 | +                directory = directories[directory_node.digest.hash]
 | |
| 426 | +            else:
 | |
| 427 | +                directory = remote_execution_pb2.Directory()
 | |
| 428 | +                directory.ParseFromString(self._fetch_blob(directory_node.digest))
 | |
| 429 | + | |
| 430 | +            os.makedirs(directory_path, exist_ok=True)
 | |
| 431 | + | |
| 432 | +            self._write_directory(directory, directory_path,
 | |
| 433 | +                                  directories=directories, root_barrier=root_barrier)
 | |
| 434 | + | |
| 435 | +        for symlink_node in root_directory.symlinks:
 | |
| 436 | +            symlink_path = os.path.join(root_path, symlink_node.name)
 | |
| 437 | +            if not os.path.isabs(symlink_node.target):
 | |
| 438 | +                target_path = os.path.join(root_path, symlink_node.target)
 | |
| 439 | +            else:
 | |
| 440 | +                target_path = symlink_node.target
 | |
| 441 | +            target_path = os.path.normpath(target_path)
 | |
| 442 | + | |
| 443 | +            # Do not create links pointing outside the barrier:
 | |
| 444 | +            if root_barrier is not None:
 | |
| 445 | +                common_path = os.path.commonprefix([root_barrier, target_path])
 | |
| 446 | +                if not common_path.startswith(root_barrier):
 | |
| 447 | +                    continue
 | |
| 448 | + | |
| 449 | +            os.symlink(target_path, symlink_path)
 | |
| 450 | + | |
| 451 | + | |
| 46 | 452 |  @contextmanager
 | 
| 47 | 453 |  def upload(channel, instance=None, u_uid=None):
 | 
| 48 | 454 |      """Context manager generator for the :class:`Uploader` class."""
 | 
| ... | ... | @@ -63,16 +469,8 @@ class Uploader: | 
| 63 | 469 |  | 
| 64 | 470 |          with upload(channel, instance='build') as uploader:
 | 
| 65 | 471 |              uploader.upload_file('/path/to/local/file')
 | 
| 66 | - | |
| 67 | -    Attributes:
 | |
| 68 | -        FILE_SIZE_THRESHOLD (int): maximum size for a queueable file.
 | |
| 69 | -        MAX_REQUEST_SIZE (int): maximum size for a single gRPC request.
 | |
| 70 | 472 |      """
 | 
| 71 | 473 |  | 
| 72 | -    FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
 | |
| 73 | -    MAX_REQUEST_SIZE = 2 * 1024 * 1024
 | |
| 74 | -    MAX_REQUEST_COUNT = 500
 | |
| 75 | - | |
| 76 | 474 |      def __init__(self, channel, instance=None, u_uid=None):
 | 
| 77 | 475 |          """Initializes a new :class:`Uploader` instance.
 | 
| 78 | 476 |  | 
| ... | ... | @@ -115,7 +513,7 @@ class Uploader: | 
| 115 | 513 |          Returns:
 | 
| 116 | 514 |              :obj:`Digest`: the sent blob's digest.
 | 
| 117 | 515 |          """
 | 
| 118 | -        if not queue or len(blob) > Uploader.FILE_SIZE_THRESHOLD:
 | |
| 516 | +        if not queue or len(blob) > FILE_SIZE_THRESHOLD:
 | |
| 119 | 517 |              blob_digest = self._send_blob(blob, digest=digest)
 | 
| 120 | 518 |          else:
 | 
| 121 | 519 |              blob_digest = self._queue_blob(blob, digest=digest)
 | 
| ... | ... | @@ -141,7 +539,7 @@ class Uploader: | 
| 141 | 539 |          """
 | 
| 142 | 540 |          message_blob = message.SerializeToString()
 | 
| 143 | 541 |  | 
| 144 | -        if not queue or len(message_blob) > Uploader.FILE_SIZE_THRESHOLD:
 | |
| 542 | +        if not queue or len(message_blob) > FILE_SIZE_THRESHOLD:
 | |
| 145 | 543 |              message_digest = self._send_blob(message_blob, digest=digest)
 | 
| 146 | 544 |          else:
 | 
| 147 | 545 |              message_digest = self._queue_blob(message_blob, digest=digest)
 | 
| ... | ... | @@ -174,7 +572,7 @@ class Uploader: | 
| 174 | 572 |          with open(file_path, 'rb') as bytes_steam:
 | 
| 175 | 573 |              file_bytes = bytes_steam.read()
 | 
| 176 | 574 |  | 
| 177 | -        if not queue or len(file_bytes) > Uploader.FILE_SIZE_THRESHOLD:
 | |
| 575 | +        if not queue or len(file_bytes) > FILE_SIZE_THRESHOLD:
 | |
| 178 | 576 |              file_digest = self._send_blob(file_bytes)
 | 
| 179 | 577 |          else:
 | 
| 180 | 578 |              file_digest = self._queue_blob(file_bytes)
 | 
| ... | ... | @@ -316,7 +714,7 @@ class Uploader: | 
| 316 | 714 |              finished = False
 | 
| 317 | 715 |              remaining = len(content)
 | 
| 318 | 716 |              while not finished:
 | 
| 319 | -                chunk_size = min(remaining, Uploader.MAX_REQUEST_SIZE)
 | |
| 717 | +                chunk_size = min(remaining, MAX_REQUEST_SIZE)
 | |
| 320 | 718 |                  remaining -= chunk_size
 | 
| 321 | 719 |  | 
| 322 | 720 |                  request = bytestream_pb2.WriteRequest()
 | 
| ... | ... | @@ -347,9 +745,9 @@ class Uploader: | 
| 347 | 745 |              blob_digest.hash = HASH(blob).hexdigest()
 | 
| 348 | 746 |              blob_digest.size_bytes = len(blob)
 | 
| 349 | 747 |  | 
| 350 | -        if self.__request_size + blob_digest.size_bytes > Uploader.MAX_REQUEST_SIZE:
 | |
| 748 | +        if self.__request_size + blob_digest.size_bytes > MAX_REQUEST_SIZE:
 | |
| 351 | 749 |              self.flush()
 | 
| 352 | -        elif self.__request_count >= Uploader.MAX_REQUEST_COUNT:
 | |
| 750 | +        elif self.__request_count >= MAX_REQUEST_COUNT:
 | |
| 353 | 751 |              self.flush()
 | 
| 354 | 752 |  | 
| 355 | 753 |          self.__requests[blob_digest.hash] = (blob, blob_digest)
 | 
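Round-tripping a file through the two helpers then looks as follows; a minimal sketch, assuming a plain insecure channel and illustrative paths:

    import grpc

    from buildgrid.client.cas import download, upload

    channel = grpc.insecure_channel('localhost:50051')

    with upload(channel, instance='build') as uploader:
        file_digest = uploader.upload_file('/path/to/local/file')

    with download(channel, instance='build') as downloader:
        # Small blobs are queued and fetched in a single BatchReadBlobs()
        # request on flush()/close(); larger ones stream via ByteStream.Read().
        downloader.download_file(file_digest, '/path/to/local/copy')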
| ... | ... | @@ -68,6 +68,18 @@ class ContentAddressableStorageService(remote_execution_pb2_grpc.ContentAddressa | 
| 68 | 68 |  | 
| 69 | 69 |          return remote_execution_pb2.BatchReadBlobsResponse()
 | 
| 70 | 70 |  | 
| 71 | +    def BatchReadBlobs(self, request, context):
 | |
| 72 | +        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
 | |
| 73 | +        context.set_details('Method not implemented!')
 | |
| 74 | + | |
| 75 | +        return remote_execution_pb2.BatchReadBlobsResponse()
 | |
| 76 | + | |
| 77 | +    def GetTree(self, request, context):
 | |
| 78 | +        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
 | |
| 79 | +        context.set_details('Method not implemented!')
 | |
| 80 | + | |
| 81 | +        return iter([remote_execution_pb2.GetTreeResponse()])
 | |
| 82 | + | |
| 71 | 83 |      def _get_instance(self, instance_name):
 | 
| 72 | 84 |          try:
 | 
| 73 | 85 |              return self._instances[instance_name]
 | 
| ... | ... | @@ -171,6 +183,12 @@ class ByteStreamService(bytestream_pb2_grpc.ByteStreamServicer): | 
| 171 | 183 |  | 
| 172 | 184 |          return bytestream_pb2.WriteResponse()
 | 
| 173 | 185 |  | 
| 186 | +    def QueryWriteStatus(self, request, context):
 | |
| 187 | +        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
 | |
| 188 | +        context.set_details('Method not implemented!')
 | |
| 189 | + | |
| 190 | +        return bytestream_pb2.QueryWriteStatusResponse()
 | |
| 191 | + | |
| 174 | 192 |      def _get_instance(self, instance_name):
 | 
| 175 | 193 |          try:
 | 
| 176 | 194 |              return self._instances[instance_name]
 | 
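Returning an explicit UNIMPLEMENTED status, rather than failing generically, is what lets the client-side `_CallCache` degrade gracefully to the simpler RPCs. A sketch of that calling pattern (`batch_read_blobs` and `bytestream_read` are hypothetical wrappers, not real API):

    import grpc

    def batch_read_with_fallback(channel, digests):
        try:
            return batch_read_blobs(channel, digests)
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.UNIMPLEMENTED:
                # The real client records this verdict in _CallCache so the
                # batch RPC is only attempted once per channel.
                return [bytestream_read(channel, digest) for digest in digests]
            raise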
| ... | ... | @@ -23,14 +23,10 @@ Forwards storage requests to a remote storage. | 
| 23 | 23 |  import io
 | 
| 24 | 24 |  import logging
 | 
| 25 | 25 |  | 
| 26 | -import grpc
 | |
| 27 | - | |
| 28 | -from buildgrid.client.cas import upload
 | |
| 29 | -from buildgrid._protos.google.bytestream import bytestream_pb2_grpc
 | |
| 26 | +from buildgrid.client.cas import download, upload
 | |
| 30 | 27 |  from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
 | 
| 31 | 28 |  from buildgrid._protos.google.rpc import code_pb2
 | 
| 32 | 29 |  from buildgrid._protos.google.rpc import status_pb2
 | 
| 33 | -from buildgrid.utils import gen_fetch_blob
 | |
| 34 | 30 |  from buildgrid.settings import HASH
 | 
| 35 | 31 |  | 
| 36 | 32 |  from .storage_abc import StorageABC
 | 
| ... | ... | @@ -44,7 +40,6 @@ class RemoteStorage(StorageABC): | 
| 44 | 40 |          self.instance_name = instance_name
 | 
| 45 | 41 |          self.channel = channel
 | 
| 46 | 42 |  | 
| 47 | -        self._stub_bs = bytestream_pb2_grpc.ByteStreamStub(channel)
 | |
| 48 | 43 |          self._stub_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(channel)
 | 
| 49 | 44 |  | 
| 50 | 45 |      def has_blob(self, digest):
 | 
| ... | ... | @@ -53,25 +48,12 @@ class RemoteStorage(StorageABC): | 
| 53 | 48 |          return False
 | 
| 54 | 49 |  | 
| 55 | 50 |      def get_blob(self, digest):
 | 
| 56 | -        try:
 | |
| 57 | -            fetched_data = io.BytesIO()
 | |
| 58 | -            length = 0
 | |
| 59 | - | |
| 60 | -            for data in gen_fetch_blob(self._stub_bs, digest, self.instance_name):
 | |
| 61 | -                length += fetched_data.write(data)
 | |
| 62 | - | |
| 63 | -            assert digest.size_bytes == length
 | |
| 64 | -            fetched_data.seek(0)
 | |
| 65 | -            return fetched_data
 | |
| 66 | - | |
| 67 | -        except grpc.RpcError as e:
 | |
| 68 | -            if e.code() == grpc.StatusCode.NOT_FOUND:
 | |
| 69 | -                pass
 | |
| 51 | +        with download(self.channel, instance=self.instance_name) as downloader:
 | |
| 52 | +            blob = downloader.get_blob(digest)
 | |
| 53 | +            if blob is not None:
 | |
| 54 | +                return io.BytesIO(blob)
 | |
| 70 | 55 |              else:
 | 
| 71 | -                self.logger.error(e.details())
 | |
| 72 | -                raise
 | |
| 73 | - | |
| 74 | -        return None
 | |
| 56 | +                return None
 | |
| 75 | 57 |  | 
| 76 | 58 |      def begin_write(self, digest):
 | 
| 77 | 59 |          return io.BytesIO()
 | 
| ... | ... | @@ -18,87 +18,6 @@ import os | 
| 18 | 18 |  | 
| 19 | 19 |  from buildgrid.settings import HASH
 | 
| 20 | 20 |  from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
 | 
| 21 | -from buildgrid._protos.google.bytestream import bytestream_pb2
 | |
| 22 | - | |
| 23 | - | |
| 24 | -def gen_fetch_blob(stub, digest, instance_name=""):
 | |
| 25 | -    """ Generates byte stream from a fetch blob request
 | |
| 26 | -    """
 | |
| 27 | - | |
| 28 | -    resource_name = os.path.join(instance_name, 'blobs', digest.hash, str(digest.size_bytes))
 | |
| 29 | -    request = bytestream_pb2.ReadRequest(resource_name=resource_name,
 | |
| 30 | -                                         read_offset=0)
 | |
| 31 | - | |
| 32 | -    for response in stub.Read(request):
 | |
| 33 | -        yield response.data
 | |
| 34 | - | |
| 35 | - | |
| 36 | -def write_fetch_directory(root_directory, stub, digest, instance_name=None):
 | |
| 37 | -    """Locally replicates a directory from CAS.
 | |
| 38 | - | |
| 39 | -    Args:
 | |
| 40 | -        root_directory (str): local directory to populate.
 | |
| 41 | -        stub (): gRPC stub for CAS communication.
 | |
| 42 | -        digest (Digest): digest for the directory to fetch from CAS.
 | |
| 43 | -        instance_name (str, optional): farm instance name to query data from.
 | |
| 44 | -    """
 | |
| 45 | - | |
| 46 | -    if not os.path.isabs(root_directory):
 | |
| 47 | -        root_directory = os.path.abspath(root_directory)
 | |
| 48 | -    if not os.path.exists(root_directory):
 | |
| 49 | -        os.makedirs(root_directory, exist_ok=True)
 | |
| 50 | - | |
| 51 | -    directory = parse_to_pb2_from_fetch(remote_execution_pb2.Directory(),
 | |
| 52 | -                                        stub, digest, instance_name)
 | |
| 53 | - | |
| 54 | -    for directory_node in directory.directories:
 | |
| 55 | -        child_path = os.path.join(root_directory, directory_node.name)
 | |
| 56 | - | |
| 57 | -        write_fetch_directory(child_path, stub, directory_node.digest, instance_name)
 | |
| 58 | - | |
| 59 | -    for file_node in directory.files:
 | |
| 60 | -        child_path = os.path.join(root_directory, file_node.name)
 | |
| 61 | - | |
| 62 | -        with open(child_path, 'wb') as child_file:
 | |
| 63 | -            write_fetch_blob(child_file, stub, file_node.digest, instance_name)
 | |
| 64 | - | |
| 65 | -    for symlink_node in directory.symlinks:
 | |
| 66 | -        child_path = os.path.join(root_directory, symlink_node.name)
 | |
| 67 | - | |
| 68 | -        if os.path.isabs(symlink_node.target):
 | |
| 69 | -            continue  # No out of temp-directory links for now.
 | |
| 70 | -        target_path = os.path.join(root_directory, symlink_node.target)
 | |
| 71 | - | |
| 72 | -        os.symlink(child_path, target_path)
 | |
| 73 | - | |
| 74 | - | |
| 75 | -def write_fetch_blob(target_file, stub, digest, instance_name=None):
 | |
| 76 | -    """Extracts a blob from CAS into a local file.
 | |
| 77 | - | |
| 78 | -    Args:
 | |
| 79 | -        target_file (str): local file to write.
 | |
| 80 | -        stub (): gRPC stub for CAS communication.
 | |
| 81 | -        digest (Digest): digest for the blob to fetch from CAS.
 | |
| 82 | -        instance_name (str, optional): farm instance name to query data from.
 | |
| 83 | -    """
 | |
| 84 | - | |
| 85 | -    for stream in gen_fetch_blob(stub, digest, instance_name):
 | |
| 86 | -        target_file.write(stream)
 | |
| 87 | -    target_file.flush()
 | |
| 88 | - | |
| 89 | -    assert digest.size_bytes == os.fstat(target_file.fileno()).st_size
 | |
| 90 | - | |
| 91 | - | |
| 92 | -def parse_to_pb2_from_fetch(pb2, stub, digest, instance_name=""):
 | |
| 93 | -    """ Fetches stream and parses it into given pb2
 | |
| 94 | -    """
 | |
| 95 | - | |
| 96 | -    stream_bytes = b''
 | |
| 97 | -    for stream in gen_fetch_blob(stub, digest, instance_name):
 | |
| 98 | -        stream_bytes += stream
 | |
| 99 | - | |
| 100 | -    pb2.ParseFromString(stream_bytes)
 | |
| 101 | -    return pb2
 | |
| 102 | 21 |  | 
| 103 | 22 |  | 
| 104 | 23 |  def create_digest(bytes_to_digest):
 | 
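The removed helpers map onto the new `Downloader` API roughly as follows (a rough correspondence, not part of the diff):

    # gen_fetch_blob(stub, digest, ...)          -> Downloader.get_blob(digest)
    # write_fetch_blob(file, stub, digest, ...)  -> Downloader.download_file(digest, path)
    # write_fetch_directory(root, stub, digest)  -> Downloader.download_directory(digest, path)
    # parse_to_pb2_from_fetch(pb2, stub, digest) -> Downloader.get_message(digest, message)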
| ... | ... | @@ -50,17 +50,31 @@ BuildGrid's Command Line Interface (CLI) reference documentation. | 
| 50 | 50 |  | 
| 51 | 51 |  ----
 | 
| 52 | 52 |  | 
| 53 | +.. _invoking-bgd-cas-download-dir:
 | |
| 54 | + | |
| 55 | +.. click:: buildgrid._app.commands.cmd_cas:download_directory
 | |
| 56 | +   :prog: bgd cas download-dir
 | |
| 57 | + | |
| 58 | +----
 | |
| 59 | + | |
| 60 | +.. _invoking-bgd-cas-download-file:
 | |
| 61 | + | |
| 62 | +.. click:: buildgrid._app.commands.cmd_cas:download_file
 | |
| 63 | +   :prog: bgd cas download-file
 | |
| 64 | + | |
| 65 | +----
 | |
| 66 | + | |
| 53 | 67 |  .. _invoking-bgd-cas-upload-dir:
 | 
| 54 | 68 |  | 
| 55 | -.. click:: buildgrid._app.commands.cmd_cas:upload_dir
 | |
| 69 | +.. click:: buildgrid._app.commands.cmd_cas:upload_directory
 | |
| 56 | 70 |     :prog: bgd cas upload-dir
 | 
| 57 | 71 |  | 
| 58 | 72 |  ----
 | 
| 59 | 73 |  | 
| 60 | -.. _invoking-bgd-cas-upload-files:
 | |
| 74 | +.. _invoking-bgd-cas-upload-file:
 | |
| 61 | 75 |  | 
| 62 | -.. click:: buildgrid._app.commands.cmd_cas:upload_files
 | |
| 63 | -   :prog: bgd cas upload-files
 | |
| 76 | +.. click:: buildgrid._app.commands.cmd_cas:upload_file
 | |
| 77 | +   :prog: bgd cas upload-file
 | |
| 64 | 78 |  | 
| 65 | 79 |  ----
 | 
| 66 | 80 |  | 
| ... | ... | @@ -73,7 +87,7 @@ BuildGrid's Command Line Interface (CLI) reference documentation. | 
| 73 | 87 |  | 
| 74 | 88 |  .. _invoking-bgd-execute-command:
 | 
| 75 | 89 |  | 
| 76 | -.. click:: buildgrid._app.commands.cmd_execute:command
 | |
| 90 | +.. click:: buildgrid._app.commands.cmd_execute:run_command
 | |
| 77 | 91 |     :prog: bgd execute command
 | 
| 78 | 92 |  | 
| 79 | 93 |  ----
 | 
| ... | ... | @@ -112,6 +126,7 @@ BuildGrid's Command Line Interface (CLI) reference documentation. | 
| 112 | 126 |     :prog: bgd operation wait
 | 
| 113 | 127 |  | 
| 114 | 128 |  ----
 | 
| 129 | + | |
| 115 | 130 |  .. _invoking-bgd-server:
 | 
| 116 | 131 |  | 
| 117 | 132 |  .. click:: buildgrid._app.commands.cmd_server:cli
 | 
| ... | ... | @@ -14,12 +14,15 @@ | 
| 14 | 14 |  | 
| 15 | 15 |  # pylint: disable=redefined-outer-name
 | 
| 16 | 16 |  | 
| 17 | + | |
| 18 | +from copy import deepcopy
 | |
| 17 | 19 |  import os
 | 
| 20 | +import tempfile
 | |
| 18 | 21 |  | 
| 19 | 22 |  import grpc
 | 
| 20 | 23 |  import pytest
 | 
| 21 | 24 |  | 
| 22 | -from buildgrid.client.cas import upload
 | |
| 25 | +from buildgrid.client.cas import download, upload
 | |
| 23 | 26 |  from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
 | 
| 24 | 27 |  from buildgrid.utils import create_digest
 | 
| 25 | 28 |  | 
| ... | ... | @@ -41,6 +44,8 @@ FILES = [ | 
| 41 | 44 |      (os.path.join(DATA_DIR, 'hello.cc'),),
 | 
| 42 | 45 |      (os.path.join(DATA_DIR, 'hello', 'hello.c'),
 | 
| 43 | 46 |       os.path.join(DATA_DIR, 'hello', 'hello.h'))]
 | 
| 47 | +FOLDERS = [
 | |
| 48 | +    (os.path.join(DATA_DIR, 'hello'),)]
 | |
| 44 | 49 |  DIRECTORIES = [
 | 
| 45 | 50 |      (os.path.join(DATA_DIR, 'hello'),),
 | 
| 46 | 51 |      (os.path.join(DATA_DIR, 'hello'), DATA_DIR)]
 | 
| ... | ... | @@ -214,3 +219,145 @@ def test_upload_tree(instance, directory_paths): | 
| 214 | 219 |              directory_digest = create_digest(tree.root.SerializeToString())
 | 
| 215 | 220 |  | 
| 216 | 221 |              assert server.compare_directories(directory_digest, directory_path)
 | 
| 222 | + | |
| 223 | + | |
| 224 | +@pytest.mark.parametrize('blobs', BLOBS)
 | |
| 225 | +@pytest.mark.parametrize('instance', INTANCES)
 | |
| 226 | +def test_download_blob(instance, blobs):
 | |
| 227 | +    # Actual test function, to be run in a subprocess:
 | |
| 228 | +    def __test_download_blob(queue, remote, instance, digests):
 | |
| 229 | +        # Open a channel to the remote CAS server:
 | |
| 230 | +        channel = grpc.insecure_channel(remote)
 | |
| 231 | + | |
| 232 | +        blobs = []
 | |
| 233 | +        with download(channel, instance) as downloader:
 | |
| 234 | +            if len(digests) > 1:
 | |
| 235 | +                blobs.extend(downloader.get_blobs(digests))
 | |
| 236 | +            else:
 | |
| 237 | +                blobs.append(downloader.get_blob(digests[0]))
 | |
| 238 | + | |
| 239 | +        queue.put(blobs)
 | |
| 240 | + | |
| 241 | +    # Start a minimal CAS server in a subprocess:
 | |
| 242 | +    with serve_cas([instance]) as server:
 | |
| 243 | +        digests = []
 | |
| 244 | +        for blob in blobs:
 | |
| 245 | +            digest = server.store_blob(blob)
 | |
| 246 | +            digests.append(digest)
 | |
| 247 | + | |
| 248 | +        blobs = run_in_subprocess(__test_download_blob,
 | |
| 249 | +                                  server.remote, instance, digests)
 | |
| 250 | + | |
| 251 | +        for digest, blob in zip(digests, blobs):
 | |
| 252 | +            assert server.compare_blobs(digest, blob)
 | |
| 253 | + | |
| 254 | + | |
| 255 | +@pytest.mark.parametrize('messages', MESSAGES)
 | |
| 256 | +@pytest.mark.parametrize('instance', INTANCES)
 | |
| 257 | +def test_download_message(instance, messages):
 | |
| 258 | +    # Actual test function, to be run in a subprocess:
 | |
| 259 | +    def __test_download_message(queue, remote, instance, digests, empty_messages):
 | |
| 260 | +        # Open a channel to the remote CAS server:
 | |
| 261 | +        channel = grpc.insecure_channel(remote)
 | |
| 262 | + | |
| 263 | +        messages = []
 | |
| 264 | +        with download(channel, instance) as downloader:
 | |
| 265 | +            if len(digests) > 1:
 | |
| 266 | +                messages = downloader.get_messages(digests, empty_messages)
 | |
| 267 | +                messages = [m.SerializeToString() for m in messages]
 | |
| 268 | +            else:
 | |
| 269 | +                message = downloader.get_message(digests[0], empty_messages[0])
 | |
| 270 | +                messages.append(message.SerializeToString())
 | |
| 271 | + | |
| 272 | +        queue.put(messages)
 | |
| 273 | + | |
| 274 | +    # Start a minimal CAS server in a subprocess:
 | |
| 275 | +    with serve_cas([instance]) as server:
 | |
| 276 | +        empty_messages, digests = [], []
 | |
| 277 | +        for message in messages:
 | |
| 278 | +            digest = server.store_message(message)
 | |
| 279 | +            digests.append(digest)
 | |
| 280 | + | |
| 281 | +            empty_message = deepcopy(message)
 | |
| 282 | +            empty_message.Clear()
 | |
| 283 | +            empty_messages.append(empty_message)
 | |
| 284 | + | |
| 285 | +        messages = run_in_subprocess(__test_download_message,
 | |
| 286 | +                                     server.remote, instance, digests, empty_messages)
 | |
| 287 | + | |
| 288 | +        for digest, message_blob, message in zip(digests, messages, empty_messages):
 | |
| 289 | +            message.ParseFromString(message_blob)
 | |
| 290 | + | |
| 291 | +            assert server.compare_messages(digest, message)
 | |
| 292 | + | |
| 293 | + | |
| 294 | +@pytest.mark.parametrize('file_paths', FILES)
 | |
| 295 | +@pytest.mark.parametrize('instance', INTANCES)
 | |
| 296 | +def test_download_file(instance, file_paths):
 | |
| 297 | +    # Actual test function, to be run in a subprocess:
 | |
| 298 | +    def __test_download_file(queue, remote, instance, digests, paths):
 | |
| 299 | +        # Open a channel to the remote CAS server:
 | |
| 300 | +        channel = grpc.insecure_channel(remote)
 | |
| 301 | + | |
| 302 | +        with download(channel, instance) as downloader:
 | |
| 303 | +            if len(digests) > 1:
 | |
| 304 | +                for digest, path in zip(digests, paths):
 | |
| 305 | +                    downloader.download_file(digest, path, queue=False)
 | |
| 306 | +            else:
 | |
| 307 | +                downloader.download_file(digests[0], paths[0], queue=False)
 | |
| 308 | + | |
| 309 | +        queue.put(None)
 | |
| 310 | + | |
| 311 | +    # Start a minimal CAS server in a subprocess:
 | |
| 312 | +    with serve_cas([instance]) as server:
 | |
| 313 | +        with tempfile.TemporaryDirectory() as temp_folder:
 | |
| 314 | +            paths, digests = [], []
 | |
| 315 | +            for file_path in file_paths:
 | |
| 316 | +                digest = server.store_file(file_path)
 | |
| 317 | +                digests.append(digest)
 | |
| 318 | + | |
| 319 | +                path = os.path.relpath(file_path, start=DATA_DIR)
 | |
| 320 | +                path = os.path.join(temp_folder, path)
 | |
| 321 | +                paths.append(path)
 | |
| 322 | + | |
| 323 | +                run_in_subprocess(__test_download_file,
 | |
| 324 | +                                  server.remote, instance, digests, paths)
 | |
| 325 | + | |
| 326 | +            for digest, path in zip(digests, paths):
 | |
| 327 | +                assert server.compare_files(digest, path)
 | |
| 328 | + | |
| 329 | + | |
| 330 | +@pytest.mark.parametrize('folder_paths', FOLDERS)
 | |
| 331 | +@pytest.mark.parametrize('instance', INTANCES)
 | |
| 332 | +def test_download_directory(instance, folder_paths):
 | |
| 333 | +    # Actual test function, to be run in a subprocess:
 | |
| 334 | +    def __test_download_directory(queue, remote, instance, digests, paths):
 | |
| 335 | +        # Open a channel to the remote CAS server:
 | |
| 336 | +        channel = grpc.insecure_channel(remote)
 | |
| 337 | + | |
| 338 | +        with download(channel, instance) as downloader:
 | |
| 339 | +            if len(digests) > 1:
 | |
| 340 | +                for digest, path in zip(digests, paths):
 | |
| 341 | +                    downloader.download_directory(digest, path)
 | |
| 342 | +            else:
 | |
| 343 | +                downloader.download_directory(digests[0], paths[0])
 | |
| 344 | + | |
| 345 | +        queue.put(None)
 | |
| 346 | + | |
| 347 | +    # Start a minimal CAS server in a subprocess:
 | |
| 348 | +    with serve_cas([instance]) as server:
 | |
| 349 | +        with tempfile.TemporaryDirectory() as temp_folder:
 | |
| 350 | +            paths, digests = [], []
 | |
| 351 | +            for folder_path in folder_paths:
 | |
| 352 | +                digest = server.store_folder(folder_path)
 | |
| 353 | +                digests.append(digest)
 | |
| 354 | + | |
| 355 | +                path = os.path.relpath(folder_path, start=DATA_DIR)
 | |
| 356 | +                path = os.path.join(temp_folder, path)
 | |
| 357 | +                paths.append(path)
 | |
| 358 | + | |
| 359 | +            run_in_subprocess(__test_download_directory,
 | |
| 360 | +                              server.remote, instance, digests, paths)
 | |
| 361 | + | |
| 362 | +            for digest, path in zip(digests, paths):
 | |
| 363 | +                assert server.compare_directories(digest, path) | 
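Taken together, the four tests above pin down the public surface of the new download() context manager: get_blob()/get_blobs(), get_message()/get_messages(), download_file() and download_directory(). A condensed usage sketch, assuming an insecure channel to a running CAS endpoint; the address, instance name and digest values are placeholders:

    import grpc

    from buildgrid.client.cas import download
    from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2

    channel = grpc.insecure_channel('localhost:50051')  # placeholder endpoint
    digest = remote_execution_pb2.Digest(hash='...', size_bytes=0)  # known digest

    with download(channel, 'main') as downloader:  # 'main' is an illustrative instance
        data = downloader.get_blob(digest)  # raw blob bytes
        action = downloader.get_message(digest, remote_execution_pb2.Action())
        downloader.download_file(digest, '/tmp/blob.out', queue=False)
        downloader.download_directory(digest, '/tmp/blob.dir')

Note that get_message() parses into the empty message it is handed and returns it, which is why test_download_message has to ship Clear()-ed deep copies of the originals to the subprocess.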
| ... | ... | @@ -30,6 +30,7 @@ from buildgrid.server.cas.service import ContentAddressableStorageService | 
| 30 | 30 |  from buildgrid.server.cas.instance import ByteStreamInstance
 | 
| 31 | 31 |  from buildgrid.server.cas.instance import ContentAddressableStorageInstance
 | 
| 32 | 32 |  from buildgrid.server.cas.storage.disk import DiskStorage
 | 
| 33 | +from buildgrid.utils import create_digest, merkle_tree_maker
 | |
| 33 | 34 |  | 
| 34 | 35 |  | 
| 35 | 36 |  @contextmanager
 | 
| ... | ... | @@ -124,6 +125,15 @@ class Server: | 
| 124 | 125 |      def get(self, digest):
 | 
| 125 | 126 |          return self.__storage.get_blob(digest).read()
 | 
| 126 | 127 |  | 
| 128 | +    def store_blob(self, blob):
 | |
| 129 | +        digest = create_digest(blob)
 | |
| 130 | +        write_buffer = self.__storage.begin_write(digest)
 | |
| 131 | +        write_buffer.write(blob)
 | |
| 132 | + | |
| 133 | +        self.__storage.commit_write(digest, write_buffer)
 | |
| 134 | + | |
| 135 | +        return digest
 | |
| 136 | + | |
| 127 | 137 |      def compare_blobs(self, digest, blob):
 | 
| 128 | 138 |          if not self.__storage.has_blob(digest):
 | 
| 129 | 139 |              return False
 | 
| ... | ... | @@ -133,6 +143,16 @@ class Server: | 
| 133 | 143 |  | 
| 134 | 144 |          return blob == stored_blob
 | 
| 135 | 145 |  | 
| 146 | +    def store_message(self, message):
 | |
| 147 | +        message_blob = message.SerializeToString()
 | |
| 148 | +        message_digest = create_digest(message_blob)
 | |
| 149 | +        write_buffer = self.__storage.begin_write(message_digest)
 | |
| 150 | +        write_buffer.write(message_blob)
 | |
| 151 | + | |
| 152 | +        self.__storage.commit_write(message_digest, write_buffer)
 | |
| 153 | + | |
| 154 | +        return message_digest
 | |
| 155 | + | |
| 136 | 156 |      def compare_messages(self, digest, message):
 | 
| 137 | 157 |          if not self.__storage.has_blob(digest):
 | 
| 138 | 158 |              return False
 | 
| ... | ... | @@ -144,6 +164,17 @@ class Server: | 
| 144 | 164 |  | 
| 145 | 165 |          return message_blob == stored_blob
 | 
| 146 | 166 |  | 
| 167 | +    def store_file(self, file_path):
 | |
| 168 | +        with open(file_path, 'rb') as file_bytes:
 | |
| 169 | +            file_blob = file_bytes.read()
 | |
| 170 | +        file_digest = create_digest(file_blob)
 | |
| 171 | +        write_buffer = self.__storage.begin_write(file_digest)
 | |
| 172 | +        write_buffer.write(file_blob)
 | |
| 173 | + | |
| 174 | +        self.__storage.commit_write(file_digest, write_buffer)
 | |
| 175 | + | |
| 176 | +        return file_digest
 | |
| 177 | + | |
| 147 | 178 |      def compare_files(self, digest, file_path):
 | 
| 148 | 179 |          if not self.__storage.has_blob(digest):
 | 
| 149 | 180 |              return False
 | 
| ... | ... | @@ -156,6 +187,17 @@ class Server: | 
| 156 | 187 |  | 
| 157 | 188 |          return file_blob == stored_blob
 | 
| 158 | 189 |  | 
| 190 | +    def store_folder(self, folder_path):
 | |
| 191 | +        last_digest = None
 | |
| 192 | +        for node, blob, _ in merkle_tree_maker(folder_path):
 | |
| 193 | +            write_buffer = self.__storage.begin_write(node.digest)
 | |
| 194 | +            write_buffer.write(blob)
 | |
| 195 | + | |
| 196 | +            self.__storage.commit_write(node.digest, write_buffer)
 | |
| 197 | +            last_digest = node.digest
 | |
| 198 | + | |
| 199 | +        return last_digest
 | |
| 200 | + | |
| 159 | 201 |      def compare_directories(self, digest, directory_path):
 | 
| 160 | 202 |          if not self.__storage.has_blob(digest):
 | 
| 161 | 203 |              return False
 | 
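These helpers give the test server a symmetric seed-and-verify path: each store_*() method pushes a blob through begin_write()/commit_write() and returns its digest, the matching compare_*() later checks the round trip, and store_folder() relies on merkle_tree_maker() yielding the root node last so that the returned digest addresses the whole tree. A compressed sketch of the pattern inside a test body; the instance name and folder path are placeholders:

    with serve_cas(['main']) as server:
        digest = server.store_folder('/path/to/sample')  # seeds every tree node
        # ... exercise the client code under test against server.remote ...
        assert server.compare_directories(digest, '/path/to/sample')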
