Skip to content

Commit

Permalink
Merge pull request #6 from pin/misc-stuff
Browse files Browse the repository at this point in the history
Misc stuff from George
  • Loading branch information
pin committed Dec 13, 2015
2 parents ab5b5e0 + 4056ba4 commit 8b1a82f
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 24 deletions.
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (c Client) Put(filename string, mode string, handler func(w *io.PipeWriter)
handler(writer)
wg.Done()
}()
s.Run(false)
s.run(false)
wg.Wait()
return nil
}
Expand All @@ -105,7 +105,7 @@ func (c Client) Get(filename string, mode string, handler func(r *io.PipeReader)
handler(reader)
wg.Done()
}()
r.Run(false)
r.run(false)
wg.Wait()
return fmt.Errorf("Send timeout")
}
7 changes: 3 additions & 4 deletions packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (p *ERROR) Pack() []byte {
return buffer.Bytes()
}

func ParsePacket(data []byte) (*Packet, error) {
func ParsePacket(data []byte) (Packet, error) {
var p Packet
opcode := binary.BigEndian.Uint16(data)
switch opcode {
Expand All @@ -168,8 +168,7 @@ func ParsePacket(data []byte) (*Packet, error) {
case OP_ERROR:
p = &ERROR{}
default:
return nil, fmt.Errorf("Unknown packet type: %d", opcode)
return nil, fmt.Errorf("unknown opcode: %d", opcode)
}
pp := Packet(p)
return &pp, pp.Unpack(data)
return p, p.Unpack(data)
}
13 changes: 8 additions & 5 deletions receiver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tftp

import (
"errors"
"fmt"
"io"
"log"
Expand All @@ -17,14 +18,16 @@ type receiver struct {
log *log.Logger
}

func (r *receiver) Run(isServerMode bool) error {
var ErrReceiveTimeout = errors.New("receive timeout")

func (r *receiver) run(serverMode bool) error {
var blockNumber uint16
blockNumber = 1
var buffer []byte
buffer = make([]byte, MAX_DATAGRAM_SIZE)
firstBlock := true
for {
last, e := r.receiveBlock(buffer, blockNumber, firstBlock && !isServerMode)
last, e := r.receiveBlock(buffer, blockNumber, firstBlock && !serverMode)
if e != nil {
if r.log != nil {
r.log.Printf("Error receiving block %d: %v", blockNumber, e)
Expand Down Expand Up @@ -69,7 +72,7 @@ func (r *receiver) receiveBlock(b []byte, n uint16, firstBlockOnClient bool) (la
if e != nil {
continue
}
switch p := Packet(*packet).(type) {
switch p := packet.(type) {
case *DATA:
r.log.Printf("got DATA #%d (%d bytes)", p.BlockNumber, len(p.Data))
if n == p.BlockNumber {
Expand All @@ -90,7 +93,7 @@ func (r *receiver) receiveBlock(b []byte, n uint16, firstBlockOnClient bool) (la
}
}
}
return false, fmt.Errorf("Receive timeout")
return false, ErrReceiveTimeout
}

func (r *receiver) terminate(b []byte, n uint16, dallying bool) (e error) {
Expand All @@ -117,7 +120,7 @@ func (r *receiver) terminate(b []byte, n uint16, dallying bool) (e error) {
if e != nil {
continue
}
switch p := Packet(*packet).(type) {
switch p := packet.(type) {
case *DATA:
r.log.Printf("got DATA #%d (%d bytes)", p.BlockNumber, len(p.Data))
if n == p.BlockNumber {
Expand Down
23 changes: 13 additions & 10 deletions sender.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tftp

import (
"errors"
"fmt"
"io"
"log"
Expand All @@ -17,15 +18,17 @@ type sender struct {
log *log.Logger
}

func (s *sender) Run(isServerMode bool) {
var ErrSendTimeout = errors.New("send timeout")

func (s *sender) run(serverMode bool) {
var buffer, tmp []byte
buffer = make([]byte, BLOCK_SIZE)
tmp = make([]byte, MAX_DATAGRAM_SIZE)
if !isServerMode {
e := s.sendRequest(tmp)
if e != nil {
s.log.Printf("Error starting transmission: %v", e)
s.reader.CloseWithError(e)
if !serverMode {
err := s.sendRequest(tmp)
if err != nil {
s.log.Printf("Error starting transmission: %v", err)
s.reader.CloseWithError(err)
return
}
}
Expand Down Expand Up @@ -93,7 +96,7 @@ func (s *sender) sendRequest(tmp []byte) (e error) {
if e != nil {
continue
}
switch p := Packet(*packet).(type) {
switch p := packet.(type) {
case *ACK:
if p.BlockNumber == 0 {
s.log.Printf("got ACK #0")
Expand All @@ -105,7 +108,7 @@ func (s *sender) sendRequest(tmp []byte) (e error) {
}
}
}
return fmt.Errorf("Send timeout")
return ErrSendTimeout
}

func (s *sender) sendBlock(b []byte, c int, n uint16, tmp []byte) (e error) {
Expand All @@ -128,7 +131,7 @@ func (s *sender) sendBlock(b []byte, c int, n uint16, tmp []byte) (e error) {
if e != nil {
continue
}
switch p := Packet(*packet).(type) {
switch p := packet.(type) {
case *ACK:
s.log.Printf("got ACK #%d", p.BlockNumber)
if n == p.BlockNumber {
Expand All @@ -139,5 +142,5 @@ func (s *sender) sendBlock(b []byte, c int, n uint16, tmp []byte) (e error) {
}
}
}
return fmt.Errorf("Send timeout")
return ErrSendTimeout
}
6 changes: 3 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (s Server) processRequest(conn *net.UDPConn) error {
if e != nil {
return nil
}
switch p := Packet(*p).(type) {
switch p := p.(type) {
case *WRQ:
s.Log.Printf("got WRQ (filename=%s, mode=%s)", p.Filename, p.Mode)
trasnmissionConn, e := s.transmissionConn()
Expand All @@ -102,7 +102,7 @@ func (s Server) processRequest(conn *net.UDPConn) error {
s.Log.Printf("sent ERROR (code=%d): %s", 1, e.Error())
return e
}
go r.Run(true)
go r.run(true)
case *RRQ:
s.Log.Printf("got RRQ (filename=%s, mode=%s)", p.Filename, p.Mode)
trasnmissionConn, e := s.transmissionConn()
Expand All @@ -112,7 +112,7 @@ func (s Server) processRequest(conn *net.UDPConn) error {
reader, writer := io.Pipe()
r := &sender{remoteAddr, trasnmissionConn, reader, p.Filename, p.Mode, s.Log}
go s.WriteHandler(p.Filename, writer)
go r.Run(true)
go r.run(true)
}
return nil
}
Expand Down
49 changes: 49 additions & 0 deletions tftp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,55 @@ func TestPutGet(t *testing.T) {
}
}

func TestTimeout(t *testing.T) {
addr, _ := net.ResolveUDPAddr("udp", "localhost:12322")

log := log.New(os.Stderr, "", log.Ldate|log.Ltime)

writeHandler := func(filename string, r *io.PipeReader) {
buf := make([]byte, 64)
for i := 0; i < 5; i++ {
_, err := r.Read(buf)
if err != nil {
panic(err)
}
}
// server "fail" during receive
}

readHandler := func(filename string, w *io.PipeWriter) {
for i := 0; i < 5; i++ {
_, err := w.Write(randomByteArray(64))
if err != nil {
panic(err)
}
}
// server "fail" during send
}

s = &Server{addr, writeHandler, readHandler, log}
go s.Serve()

c = &Client{addr, log}

var err error
c.Put("test", "octet", func(writer *io.PipeWriter) {
_, err = writer.Write(randomByteArray(5000))
writer.Close()
})
if err != ErrSendTimeout {
t.Fatalf("Send timeout expected, got %v", err)
}

buf := new(bytes.Buffer)
c.Get("test", "octet", func(reader *io.PipeReader) {
_, err = buf.ReadFrom(reader)
})
if err != ErrReceiveTimeout {
t.Fatalf("Receive timeout expected, got %v", err)
}
}

func randomByteArray(n int) []byte {
bs := make([]byte, n)
for i := 0; i < n; i++ {
Expand Down

0 comments on commit 8b1a82f

Please sign in to comment.