Skip to content

Commit

Permalink
Fixed serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
martindevans committed Oct 20, 2023
1 parent 768747c commit f621ec6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
22 changes: 11 additions & 11 deletions LLama/Abstractions/IModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ public override int GetHashCode()
public sealed class TensorSplitsCollection
: IEnumerable<float>
{
private readonly float[] _splits = new float[NativeApi.llama_max_devices()];
internal readonly float[] Splits = new float[NativeApi.llama_max_devices()];

/// <summary>
/// The size of this array
/// </summary>
public int Length => _splits.Length;
public int Length => Splits.Length;

/// <summary>
/// Get or set the proportion of work to do on the given device.
Expand All @@ -123,8 +123,8 @@ public sealed class TensorSplitsCollection
/// <returns></returns>
public float this[int index]
{
get => _splits[index];
set => _splits[index] = value;
get => Splits[index];
set => Splits[index] = value;
}

/// <summary>
Expand All @@ -134,9 +134,9 @@ public float this[int index]
/// <exception cref="ArgumentException"></exception>
public TensorSplitsCollection(float[] splits)
{
if (splits.Length != _splits.Length)
throw new ArgumentException($"tensor splits length must equal {_splits.Length}");
_splits = splits;
if (splits.Length != Splits.Length)
throw new ArgumentException($"tensor splits length must equal {Splits.Length}");
Splits = splits;
}

/// <summary>
Expand All @@ -151,25 +151,25 @@ public TensorSplitsCollection()
/// </summary>
public void Clear()
{
Array.Clear(_splits, 0, _splits.Length);
Array.Clear(Splits, 0, Splits.Length);
}

internal MemoryHandle Pin()
{
return _splits.AsMemory().Pin();
return Splits.AsMemory().Pin();
}

#region IEnumerator
/// <inheritdoc />
public IEnumerator<float> GetEnumerator()
{
return ((IEnumerable<float>)_splits).GetEnumerator();
return ((IEnumerable<float>)Splits).GetEnumerator();
}

/// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator()
{
return _splits.GetEnumerator();
return Splits.GetEnumerator();
}
#endregion
}
Expand Down
2 changes: 1 addition & 1 deletion LLama/Common/ModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ internal class TensorSplitsCollectionConverter

public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
{
JsonSerializer.Serialize(writer, value.Data, options);
JsonSerializer.Serialize(writer, value.Splits, options);
}
}
}

0 comments on commit f621ec6

Please sign in to comment.