diff --git a/Library/Homebrew/download_strategy.rb b/Library/Homebrew/download_strategy.rb index b50ae317c7..b0e31e025b 100644 --- a/Library/Homebrew/download_strategy.rb +++ b/Library/Homebrew/download_strategy.rb @@ -594,6 +594,70 @@ class GitHubPrivateRepositoryReleaseDownloadStrategy < GitHubPrivateRepositoryDo end end +# ScpDownloadStrategy downloads files using ssh via scp. To use it, add +# ":using => ScpDownloadStrategy" to the URL section of your formula or +# provide a URL starting with scp://. This strategy uses ssh credentials for +# authentication. If a public/private keypair is configured, it will not +# prompt for a password. +# +# Usage: +# +# class Abc < Formula +# url "scp://example.com/src/abc.1.0.tar.gz" +# ... +class ScpDownloadStrategy < AbstractFileDownloadStrategy + attr_reader :tarball_path, :temporary_path + + def initialize(name, resource) + super + @tarball_path = HOMEBREW_CACHE/"#{name}-#{version}#{ext}" + @temporary_path = Pathname.new("#{cached_location}.incomplete") + parse_url_pattern + end + + def parse_url_pattern + url_pattern = %r{scp://([^@]+@)?([^@:/]+)(:\d+)?/(\S+)} + if @url !~ url_pattern + raise ScpDownloadStrategyError, "Invalid URL for scp: #{@url}" + end + + _, @user, @host, @port, @path = *@url.match(url_pattern) + end + + def fetch + ohai "Downloading #{@url}" + + if cached_location.exist? + puts "Already downloaded: #{cached_location}" + else + begin + safe_system "scp", scp_source, temporary_path.to_s + rescue ErrorDuringExecution + raise ScpDownloadStrategyError, "Failed to run scp #{scp_source}" + end + + ignore_interrupts { temporary_path.rename(cached_location) } + end + end + + def cached_location + tarball_path + end + + def clear_cache + super + rm_rf(temporary_path) + end + + private + + def scp_source + path_prefix = "/" unless @path.start_with?("~") + port_arg = "-P #{@port[1..-1]} " if @port + "#{port_arg}#{@user}#{@host}:#{path_prefix}#{@path}" + end +end + class SubversionDownloadStrategy < VCSDownloadStrategy def initialize(name, resource) super @@ -1140,6 +1204,8 @@ class DownloadStrategyDetector when %r{^s3://} require_aws_sdk S3DownloadStrategy + when %r{^scp://} + ScpDownloadStrategy else CurlDownloadStrategy end diff --git a/Library/Homebrew/exceptions.rb b/Library/Homebrew/exceptions.rb index b3acf821ec..a33ed47abf 100644 --- a/Library/Homebrew/exceptions.rb +++ b/Library/Homebrew/exceptions.rb @@ -515,6 +515,13 @@ class CurlDownloadStrategyError < RuntimeError end end +# raised in ScpDownloadStrategy.fetch +class ScpDownloadStrategyError < RuntimeError + def initialize(cause) + super "Download failed: #{cause}" + end +end + # raised by safe_system in utils.rb class ErrorDuringExecution < RuntimeError def initialize(cmd, args = []) diff --git a/Library/Homebrew/test/download_strategies_spec.rb b/Library/Homebrew/test/download_strategies_spec.rb index 08bbb58e89..f4857787ef 100644 --- a/Library/Homebrew/test/download_strategies_spec.rb +++ b/Library/Homebrew/test/download_strategies_spec.rb @@ -268,6 +268,90 @@ describe CurlDownloadStrategy do end end +describe ScpDownloadStrategy do + def resource_for(url) + double(Resource, url: url, mirrors: [], specs: {}, version: nil) + end + + subject { described_class.new(name, resource) } + let(:name) { "foo" } + let(:url) { "scp://example.com/foo.tar.gz" } + let(:resource) { resource_for(url) } + + describe "#initialize" do + invalid_urls = %w[ + http://example.com/foo.tar.gz + scp://@example.com/foo.tar.gz + scp://example.com:/foo.tar.gz + scp://example.com + ] + + invalid_urls.each do |invalid_url| + context "with invalid URL #{invalid_url}" do + it "raises ScpDownloadStrategyError" do + expect { + described_class.new(name, resource_for(invalid_url)) + }.to raise_error(ScpDownloadStrategyError) + end + end + end + end + + describe "#fetch" do + before do + expect(subject.temporary_path).to receive(:rename).and_return(true) + end + + context "when given a valid URL" do + let(:url) { "scp://example.com/foo.tar.gz" } + it "copies the file via scp" do + expect(subject) + .to receive(:safe_system) + .with("scp", "example.com:/foo.tar.gz", anything) + .and_return(true) + + subject.fetch + end + end + + context "when given a URL with a username" do + let(:url) { "scp://user@example.com/foo.tar.gz" } + it "copies the file via scp" do + expect(subject) + .to receive(:safe_system) + .with("scp", "user@example.com:/foo.tar.gz", anything) + .and_return(true) + + subject.fetch + end + end + + context "when given a URL with a port" do + let(:url) { "scp://example.com:1234/foo.tar.gz" } + it "copies the file via scp" do + expect(subject) + .to receive(:safe_system) + .with("scp", "-P 1234 example.com:/foo.tar.gz", anything) + .and_return(true) + + subject.fetch + end + end + + context "when given a URL with /~/" do + let(:url) { "scp://example.com/~/foo.tar.gz" } + it "treats the path as relative to the home directory" do + expect(subject) + .to receive(:safe_system) + .with("scp", "example.com:~/foo.tar.gz", anything) + .and_return(true) + + subject.fetch + end + end + end +end + describe DownloadStrategyDetector do describe "::detect" do subject { described_class.detect(url, strategy) } @@ -306,6 +390,11 @@ describe DownloadStrategyDetector do end end + context "when given an scp URL" do + let(:url) { "scp://example.com/brew.tar.gz" } + it { is_expected.to eq(ScpDownloadStrategy) } + end + it "defaults to cURL" do expect(subject).to eq(CurlDownloadStrategy) end